Source file src/crypto/tls/handshake_messages_test.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"math/rand"
    10  	"reflect"
    11  	"strings"
    12  	"testing"
    13  	"testing/quick"
    14  	"time"
    15  )
    16  
    17  var tests = []any{
    18  	&clientHelloMsg{},
    19  	&serverHelloMsg{},
    20  	&finishedMsg{},
    21  
    22  	&certificateMsg{},
    23  	&certificateRequestMsg{},
    24  	&certificateVerifyMsg{
    25  		hasSignatureAlgorithm: true,
    26  	},
    27  	&certificateStatusMsg{},
    28  	&clientKeyExchangeMsg{},
    29  	&newSessionTicketMsg{},
    30  	&sessionState{},
    31  	&sessionStateTLS13{},
    32  	&encryptedExtensionsMsg{},
    33  	&endOfEarlyDataMsg{},
    34  	&keyUpdateMsg{},
    35  	&newSessionTicketMsgTLS13{},
    36  	&certificateRequestMsgTLS13{},
    37  	&certificateMsgTLS13{},
    38  }
    39  
    40  func TestMarshalUnmarshal(t *testing.T) {
    41  	rand := rand.New(rand.NewSource(time.Now().UnixNano()))
    42  
    43  	for i, iface := range tests {
    44  		ty := reflect.ValueOf(iface).Type()
    45  
    46  		n := 100
    47  		if testing.Short() {
    48  			n = 5
    49  		}
    50  		for j := 0; j < n; j++ {
    51  			v, ok := quick.Value(ty, rand)
    52  			if !ok {
    53  				t.Errorf("#%d: failed to create value", i)
    54  				break
    55  			}
    56  
    57  			m1 := v.Interface().(handshakeMessage)
    58  			marshaled := m1.marshal()
    59  			m2 := iface.(handshakeMessage)
    60  			if !m2.unmarshal(marshaled) {
    61  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    62  				break
    63  			}
    64  			m2.marshal() // to fill any marshal cache in the message
    65  
    66  			if !reflect.DeepEqual(m1, m2) {
    67  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
    68  				break
    69  			}
    70  
    71  			if i >= 3 {
    72  				// The first three message types (ClientHello,
    73  				// ServerHello and Finished) are allowed to
    74  				// have parsable prefixes because the extension
    75  				// data is optional and the length of the
    76  				// Finished varies across versions.
    77  				for j := 0; j < len(marshaled); j++ {
    78  					if m2.unmarshal(marshaled[0:j]) {
    79  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    80  						break
    81  					}
    82  				}
    83  			}
    84  		}
    85  	}
    86  }
    87  
    88  func TestFuzz(t *testing.T) {
    89  	rand := rand.New(rand.NewSource(0))
    90  	for _, iface := range tests {
    91  		m := iface.(handshakeMessage)
    92  
    93  		for j := 0; j < 1000; j++ {
    94  			len := rand.Intn(100)
    95  			bytes := randomBytes(len, rand)
    96  			// This just looks for crashes due to bounds errors etc.
    97  			m.unmarshal(bytes)
    98  		}
    99  	}
   100  }
   101  
   102  func randomBytes(n int, rand *rand.Rand) []byte {
   103  	r := make([]byte, n)
   104  	if _, err := rand.Read(r); err != nil {
   105  		panic("rand.Read failed: " + err.Error())
   106  	}
   107  	return r
   108  }
   109  
   110  func randomString(n int, rand *rand.Rand) string {
   111  	b := randomBytes(n, rand)
   112  	return string(b)
   113  }
   114  
   115  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   116  	m := &clientHelloMsg{}
   117  	m.vers = uint16(rand.Intn(65536))
   118  	m.random = randomBytes(32, rand)
   119  	m.sessionId = randomBytes(rand.Intn(32), rand)
   120  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   121  	for i := 0; i < len(m.cipherSuites); i++ {
   122  		cs := uint16(rand.Int31())
   123  		if cs == scsvRenegotiation {
   124  			cs += 1
   125  		}
   126  		m.cipherSuites[i] = cs
   127  	}
   128  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   129  	if rand.Intn(10) > 5 {
   130  		m.serverName = randomString(rand.Intn(255), rand)
   131  		for strings.HasSuffix(m.serverName, ".") {
   132  			m.serverName = m.serverName[:len(m.serverName)-1]
   133  		}
   134  	}
   135  	m.ocspStapling = rand.Intn(10) > 5
   136  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   137  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   138  	for i := range m.supportedCurves {
   139  		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
   140  	}
   141  	if rand.Intn(10) > 5 {
   142  		m.ticketSupported = true
   143  		if rand.Intn(10) > 5 {
   144  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   145  		} else {
   146  			m.sessionTicket = make([]byte, 0)
   147  		}
   148  	}
   149  	if rand.Intn(10) > 5 {
   150  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   151  	}
   152  	if rand.Intn(10) > 5 {
   153  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
   154  	}
   155  	for i := 0; i < rand.Intn(5); i++ {
   156  		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
   157  	}
   158  	if rand.Intn(10) > 5 {
   159  		m.scts = true
   160  	}
   161  	if rand.Intn(10) > 5 {
   162  		m.secureRenegotiationSupported = true
   163  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   164  	}
   165  	for i := 0; i < rand.Intn(5); i++ {
   166  		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
   167  	}
   168  	if rand.Intn(10) > 5 {
   169  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   170  	}
   171  	for i := 0; i < rand.Intn(5); i++ {
   172  		var ks keyShare
   173  		ks.group = CurveID(rand.Intn(30000) + 1)
   174  		ks.data = randomBytes(rand.Intn(200)+1, rand)
   175  		m.keyShares = append(m.keyShares, ks)
   176  	}
   177  	switch rand.Intn(3) {
   178  	case 1:
   179  		m.pskModes = []uint8{pskModeDHE}
   180  	case 2:
   181  		m.pskModes = []uint8{pskModeDHE, pskModePlain}
   182  	}
   183  	for i := 0; i < rand.Intn(5); i++ {
   184  		var psk pskIdentity
   185  		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
   186  		psk.label = randomBytes(rand.Intn(500)+1, rand)
   187  		m.pskIdentities = append(m.pskIdentities, psk)
   188  		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
   189  	}
   190  	if rand.Intn(10) > 5 {
   191  		m.earlyData = true
   192  	}
   193  
   194  	return reflect.ValueOf(m)
   195  }
   196  
   197  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   198  	m := &serverHelloMsg{}
   199  	m.vers = uint16(rand.Intn(65536))
   200  	m.random = randomBytes(32, rand)
   201  	m.sessionId = randomBytes(rand.Intn(32), rand)
   202  	m.cipherSuite = uint16(rand.Int31())
   203  	m.compressionMethod = uint8(rand.Intn(256))
   204  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   205  
   206  	if rand.Intn(10) > 5 {
   207  		m.ocspStapling = true
   208  	}
   209  	if rand.Intn(10) > 5 {
   210  		m.ticketSupported = true
   211  	}
   212  	if rand.Intn(10) > 5 {
   213  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   214  	}
   215  
   216  	for i := 0; i < rand.Intn(4); i++ {
   217  		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
   218  	}
   219  
   220  	if rand.Intn(10) > 5 {
   221  		m.secureRenegotiationSupported = true
   222  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   223  	}
   224  	if rand.Intn(10) > 5 {
   225  		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
   226  	}
   227  	if rand.Intn(10) > 5 {
   228  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   229  	}
   230  	if rand.Intn(10) > 5 {
   231  		for i := 0; i < rand.Intn(5); i++ {
   232  			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
   233  			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
   234  		}
   235  	} else if rand.Intn(10) > 5 {
   236  		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
   237  	}
   238  	if rand.Intn(10) > 5 {
   239  		m.selectedIdentityPresent = true
   240  		m.selectedIdentity = uint16(rand.Intn(0xffff))
   241  	}
   242  
   243  	return reflect.ValueOf(m)
   244  }
   245  
   246  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   247  	m := &encryptedExtensionsMsg{}
   248  
   249  	if rand.Intn(10) > 5 {
   250  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   251  	}
   252  
   253  	return reflect.ValueOf(m)
   254  }
   255  
   256  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   257  	m := &certificateMsg{}
   258  	numCerts := rand.Intn(20)
   259  	m.certificates = make([][]byte, numCerts)
   260  	for i := 0; i < numCerts; i++ {
   261  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   262  	}
   263  	return reflect.ValueOf(m)
   264  }
   265  
   266  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   267  	m := &certificateRequestMsg{}
   268  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   269  	for i := 0; i < rand.Intn(100); i++ {
   270  		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
   271  	}
   272  	return reflect.ValueOf(m)
   273  }
   274  
   275  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   276  	m := &certificateVerifyMsg{}
   277  	m.hasSignatureAlgorithm = true
   278  	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
   279  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   280  	return reflect.ValueOf(m)
   281  }
   282  
   283  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   284  	m := &certificateStatusMsg{}
   285  	m.response = randomBytes(rand.Intn(10)+1, rand)
   286  	return reflect.ValueOf(m)
   287  }
   288  
   289  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   290  	m := &clientKeyExchangeMsg{}
   291  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   292  	return reflect.ValueOf(m)
   293  }
   294  
   295  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   296  	m := &finishedMsg{}
   297  	m.verifyData = randomBytes(12, rand)
   298  	return reflect.ValueOf(m)
   299  }
   300  
   301  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   302  	m := &newSessionTicketMsg{}
   303  	m.ticket = randomBytes(rand.Intn(4), rand)
   304  	return reflect.ValueOf(m)
   305  }
   306  
   307  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   308  	s := &sessionState{}
   309  	s.vers = uint16(rand.Intn(10000))
   310  	s.cipherSuite = uint16(rand.Intn(10000))
   311  	s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
   312  	s.createdAt = uint64(rand.Int63())
   313  	for i := 0; i < rand.Intn(20); i++ {
   314  		s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
   315  	}
   316  	return reflect.ValueOf(s)
   317  }
   318  
   319  func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   320  	s := &sessionStateTLS13{}
   321  	s.cipherSuite = uint16(rand.Intn(10000))
   322  	s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
   323  	s.createdAt = uint64(rand.Int63())
   324  	for i := 0; i < rand.Intn(2)+1; i++ {
   325  		s.certificate.Certificate = append(
   326  			s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   327  	}
   328  	if rand.Intn(10) > 5 {
   329  		s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   330  	}
   331  	if rand.Intn(10) > 5 {
   332  		for i := 0; i < rand.Intn(2)+1; i++ {
   333  			s.certificate.SignedCertificateTimestamps = append(
   334  				s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   335  		}
   336  	}
   337  	return reflect.ValueOf(s)
   338  }
   339  
   340  func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   341  	m := &endOfEarlyDataMsg{}
   342  	return reflect.ValueOf(m)
   343  }
   344  
   345  func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   346  	m := &keyUpdateMsg{}
   347  	m.updateRequested = rand.Intn(10) > 5
   348  	return reflect.ValueOf(m)
   349  }
   350  
   351  func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   352  	m := &newSessionTicketMsgTLS13{}
   353  	m.lifetime = uint32(rand.Intn(500000))
   354  	m.ageAdd = uint32(rand.Intn(500000))
   355  	m.nonce = randomBytes(rand.Intn(100), rand)
   356  	m.label = randomBytes(rand.Intn(1000), rand)
   357  	if rand.Intn(10) > 5 {
   358  		m.maxEarlyData = uint32(rand.Intn(500000))
   359  	}
   360  	return reflect.ValueOf(m)
   361  }
   362  
   363  func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   364  	m := &certificateRequestMsgTLS13{}
   365  	if rand.Intn(10) > 5 {
   366  		m.ocspStapling = true
   367  	}
   368  	if rand.Intn(10) > 5 {
   369  		m.scts = true
   370  	}
   371  	if rand.Intn(10) > 5 {
   372  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   373  	}
   374  	if rand.Intn(10) > 5 {
   375  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
   376  	}
   377  	if rand.Intn(10) > 5 {
   378  		m.certificateAuthorities = make([][]byte, 3)
   379  		for i := 0; i < 3; i++ {
   380  			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
   381  		}
   382  	}
   383  	return reflect.ValueOf(m)
   384  }
   385  
   386  func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   387  	m := &certificateMsgTLS13{}
   388  	for i := 0; i < rand.Intn(2)+1; i++ {
   389  		m.certificate.Certificate = append(
   390  			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   391  	}
   392  	if rand.Intn(10) > 5 {
   393  		m.ocspStapling = true
   394  		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   395  	}
   396  	if rand.Intn(10) > 5 {
   397  		m.scts = true
   398  		for i := 0; i < rand.Intn(2)+1; i++ {
   399  			m.certificate.SignedCertificateTimestamps = append(
   400  				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   401  		}
   402  	}
   403  	return reflect.ValueOf(m)
   404  }
   405  
   406  func TestRejectEmptySCTList(t *testing.T) {
   407  	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
   408  
   409  	var random [32]byte
   410  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   411  	serverHello := serverHelloMsg{
   412  		vers:   VersionTLS12,
   413  		random: random[:],
   414  		scts:   [][]byte{sct},
   415  	}
   416  	serverHelloBytes := serverHello.marshal()
   417  
   418  	var serverHelloCopy serverHelloMsg
   419  	if !serverHelloCopy.unmarshal(serverHelloBytes) {
   420  		t.Fatal("Failed to unmarshal initial message")
   421  	}
   422  
   423  	// Change serverHelloBytes so that the SCT list is empty
   424  	i := bytes.Index(serverHelloBytes, sct)
   425  	if i < 0 {
   426  		t.Fatal("Cannot find SCT in ServerHello")
   427  	}
   428  
   429  	var serverHelloEmptySCT []byte
   430  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   431  	// Append the extension length and SCT list length for an empty list.
   432  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   433  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   434  
   435  	// Update the handshake message length.
   436  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   437  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   438  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   439  
   440  	// Update the extensions length
   441  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   442  	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
   443  
   444  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
   445  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   446  	}
   447  }
   448  
   449  func TestRejectEmptySCT(t *testing.T) {
   450  	// Not only must the SCT list be non-empty, but the SCT elements must
   451  	// not be zero length.
   452  
   453  	var random [32]byte
   454  	serverHello := serverHelloMsg{
   455  		vers:   VersionTLS12,
   456  		random: random[:],
   457  		scts:   [][]byte{nil},
   458  	}
   459  	serverHelloBytes := serverHello.marshal()
   460  
   461  	var serverHelloCopy serverHelloMsg
   462  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   463  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   464  	}
   465  }
   466  

View as plain text