1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "io"
19 "net"
20 "sync"
21 "sync/atomic"
22 "time"
23 )
24
25
26
// A Conn represents a secured connection. It provides the net.Conn method
// set (Read, Write, Close, LocalAddr, RemoteAddr and the deadline setters).
type Conn struct {
	// conn is the underlying transport; records are read from and written to it.
	conn net.Conn
	// isClient reports whether this side of the connection is the TLS client.
	isClient bool
	// handshakeFn performs the handshake; handshakeContext invokes it.
	handshakeFn func(context.Context) error

	// handshakeStatus is 1 once the handshake has completed and 0 otherwise
	// (handleRenegotiation resets it to 0). It is accessed with sync/atomic
	// so handshakeComplete can be called without holding handshakeMutex.
	handshakeStatus uint32
	// handshakeMutex serializes handshakes and guards the handshake results.
	handshakeMutex sync.Mutex
	// handshakeErr caches the error of the last handshake attempt, if any.
	handshakeErr error
	// vers is the negotiated TLS version; haveVers reports whether it is known.
	vers     uint16
	haveVers bool
	config   *Config

	// handshakes counts the handshakes completed on this connection,
	// including renegotiations.
	handshakes       int
	didResume        bool
	cipherSuite      uint16
	ocspResponse     []byte
	scts             [][]byte
	peerCertificates []*x509.Certificate
	// verifiedChains holds the chains built during certificate verification;
	// VerifyHostname requires it to be non-empty.
	verifiedChains [][]*x509.Certificate
	// serverName is reported as ConnectionState.ServerName.
	serverName string

	// NOTE(review): not used in this file; presumably records whether secure
	// renegotiation (RFC 5746) was negotiated.
	secureRenegotiation bool
	// ekm exports keying material; ConnectionState exposes it only when
	// renegotiation is disabled.
	ekm func(label string, context []byte, length int) ([]byte, error)
	// NOTE(review): resumptionSecret and ticketKeys are not used in this file;
	// presumably TLS 1.3 resumption state and session-ticket keys.
	resumptionSecret []byte

	ticketKeys []ticketKey

	// clientFinishedIsFirst selects which Finished verify data is reported as
	// ConnectionState.TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr is the result of sending the close_notify alert.
	closeNotifyErr error
	// closeNotifySent reports whether close_notify has been sent; afterwards
	// Write returns errShutdown.
	closeNotifySent bool

	// clientFinished and serverFinished hold Finished verify data; one of
	// them becomes ConnectionState.TLSUnique.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is reported as ConnectionState.NegotiatedProtocol
	// (presumably the ALPN result).
	clientProtocol string

	// in and out are the read and write halves of the record layer.
	in, out halfConn
	// rawInput holds bytes read from conn that have not yet been consumed as
	// records.
	rawInput bytes.Buffer
	// input holds the application data of the current record; it aliases
	// rawInput's memory (records are decrypted in place).
	input bytes.Reader
	// hand accumulates handshake bytes until a complete message is available.
	hand bytes.Buffer
	// buffering makes write append to sendBuf instead of writing to conn;
	// flush sends sendBuf and clears buffering.
	buffering bool
	sendBuf   []byte

	// bytesSent and packetsSent drive dynamic record sizing in
	// maxPayloadSizeForWrite.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records that did not advance the
	// protocol; exceeding maxUselessRecords is fatal.
	retryCount int

	// activeCall interlocks Write and Close: bit 0 is set once Close has
	// started, and each in-flight Write adds 2.
	activeCall int32

	// tmp is scratch space; sendAlertLocked uses tmp[0:2] for alert bytes.
	tmp [16]byte
}
120
121
122
123
124
125
126 func (c *Conn) LocalAddr() net.Addr {
127 return c.conn.LocalAddr()
128 }
129
130
131 func (c *Conn) RemoteAddr() net.Addr {
132 return c.conn.RemoteAddr()
133 }
134
135
136
137
138 func (c *Conn) SetDeadline(t time.Time) error {
139 return c.conn.SetDeadline(t)
140 }
141
142
143
144 func (c *Conn) SetReadDeadline(t time.Time) error {
145 return c.conn.SetReadDeadline(t)
146 }
147
148
149
150
151 func (c *Conn) SetWriteDeadline(t time.Time) error {
152 return c.conn.SetWriteDeadline(t)
153 }
154
155
156
157
158 func (c *Conn) NetConn() net.Conn {
159 return c.conn
160 }
161
162
163
// A halfConn represents one direction (read or write) of the record layer:
// its cipher state, MAC and sequence number. The embedded Mutex guards it.
type halfConn struct {
	sync.Mutex

	// err is the latched error for this direction, set by setErrorLocked.
	err error
	// version is the protocol version used for record protection.
	version uint16
	// cipher is a cipher.Stream, aead or cbcMode; nil means records are
	// not protected (the type switches in this file panic on other types).
	cipher any
	// mac is the record MAC; NOTE(review): presumably nil for AEAD suites.
	mac hash.Hash
	// seq is the 64-bit record sequence number, big-endian (incSeq
	// increments starting from the last byte).
	seq [8]byte

	// scratchBuf is reusable space for building AEAD additional data and
	// MAC input.
	scratchBuf [13]byte

	// nextCipher and nextMac are the pending state installed by
	// prepareCipherSpec and activated by changeCipherSpec.
	nextCipher any
	nextMac    hash.Hash

	// trafficSecret is the current TLS 1.3 traffic secret for this direction.
	trafficSecret []byte
}
180
// permanentError wraps a net.Error so that it is never reported as
// temporary, while still exposing the wrapped error's message, timeout
// status and identity (via Unwrap).
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message unchanged.
func (e *permanentError) Error() string {
	return e.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e *permanentError) Unwrap() error {
	return e.err
}

// Timeout delegates to the wrapped error.
func (e *permanentError) Timeout() bool {
	return e.err.Timeout()
}

// Temporary always reports false: the error has been latched.
func (e *permanentError) Temporary() bool {
	return false
}
189
190 func (hc *halfConn) setErrorLocked(err error) error {
191 if e, ok := err.(net.Error); ok {
192 hc.err = &permanentError{err: e}
193 } else {
194 hc.err = err
195 }
196 return hc.err
197 }
198
199
200
201 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
202 hc.version = version
203 hc.nextCipher = cipher
204 hc.nextMac = mac
205 }
206
207
208
209 func (hc *halfConn) changeCipherSpec() error {
210 if hc.nextCipher == nil || hc.version == VersionTLS13 {
211 return alertInternalError
212 }
213 hc.cipher = hc.nextCipher
214 hc.mac = hc.nextMac
215 hc.nextCipher = nil
216 hc.nextMac = nil
217 for i := range hc.seq {
218 hc.seq[i] = 0
219 }
220 return nil
221 }
222
223 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
224 hc.trafficSecret = secret
225 key, iv := suite.trafficKey(secret)
226 hc.cipher = suite.aead(key, iv)
227 for i := range hc.seq {
228 hc.seq[i] = 0
229 }
230 }
231
232
233 func (hc *halfConn) incSeq() {
234 for i := 7; i >= 0; i-- {
235 hc.seq[i]++
236 if hc.seq[i] != 0 {
237 return
238 }
239 }
240
241
242
243
244 panic("TLS: sequence number wraparound")
245 }
246
247
248
249
250 func (hc *halfConn) explicitNonceLen() int {
251 if hc.cipher == nil {
252 return 0
253 }
254
255 switch c := hc.cipher.(type) {
256 case cipher.Stream:
257 return 0
258 case aead:
259 return c.explicitNonceLen()
260 case cbcMode:
261
262 if hc.version >= VersionTLS11 {
263 return c.BlockSize()
264 }
265 return 0
266 default:
267 panic("unknown cipher type")
268 }
269 }
270
271
272
273
// extractPadding inspects the CBC padding at the end of payload in constant
// time. It returns toRemove, the number of trailing bytes to strip (the
// padding bytes plus the padding-length byte), and good, which is 255 if the
// padding is well formed and 0 otherwise. On bad padding, paddingLen is
// zeroed so toRemove is 1, letting the caller continue to the MAC check
// without a secret-dependent branch. An empty payload yields (0, 0).
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// If len(payload)-1 >= paddingLen, t is small and bit 31 of ^t is set,
	// so good becomes 0xff; if the subtraction wrapped, good becomes 0.
	good = byte(int32(^t) >> 31)

	// The maximum possible padding length plus the length byte itself.
	toCheck := 256
	// The length of the padded data is public, so branching on it is fine.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen, i.e. when the i-th byte from
		// the end is part of the padding and must equal paddingLen.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear any bit of good where a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all eight bits of good into the top bit, then replicate that bit
	// across the byte: good ends as 0xff only if every bit was still set.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Zero the padding length on error. This ensures any unchecked bytes
	// are included in the MAC input by the caller and keeps toRemove
	// independent of the (secret) bad padding value.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
320
// roundUp rounds a up to the next multiple of b (a itself if it already is
// one). b must be non-zero.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
324
325
// cbcMode is a cipher.BlockMode whose IV can be replaced. encrypt and
// decrypt call SetIV to install the per-record explicit IV (TLS 1.1+).
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
330
331
332
// decrypt authenticates and decrypts record (header plus payload) in place.
// It returns the plaintext, the record type, and an alert error on failure.
// For TLS 1.3 the returned type is the inner content type. When a separate
// MAC is in use, it rewrites the header length bytes record[3:5]. On success
// it increments the sequence number. The returned plaintext generally
// aliases record's memory.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, change_cipher_spec records are passed through without
	// being decrypted (RFC 8446, Appendix D.4).
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// paddingGood stays 255 unless CBC padding validation fails.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			nonce := payload[:explicitNonceLen]
			// Without an explicit nonce, the sequence number is the nonce input.
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			var additionalData []byte
			if hc.version == VersionTLS13 {
				// TLS 1.3 authenticates the record header exactly as received.
				additionalData = record[:recordHeaderLen]
			} else {
				// Pre-1.3: seq_num || type || version || plaintext length.
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// Smallest valid record: explicit IV plus MAC and one padding
			// byte, rounded up to whole blocks.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// Padding is validated in constant time, and its validity is
			// folded into the MAC check below instead of being rejected here.
			// This keeps bad padding indistinguishable from a bad MAC
			// (padding-oracle / Lucky13 hardening).
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Strip zero padding from the end; the last non-zero byte is the
			// real content type. An all-zero plaintext is rejected.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	// With a separate MAC (MAC-then-encrypt suites), verify it now.
	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		n := len(payload) - macSize - paddingLen
		// Constant-time clamp: if n < 0 { n = 0 }.
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC input header carries the plaintext length, so patch it in.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		// The trailing bytes past the MAC (the padding) are also handed to
		// tls10MAC. NOTE(review): presumably to keep MAC timing independent of
		// the secret padding length; confirm in tls10MAC.
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Combine the MAC comparison and the padding result in constant time
		// so a padding failure cannot be distinguished from a MAC failure.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
456
457
458
459
// sliceForAppend extends in by n bytes. head is the extended slice and tail
// is its last n bytes. When in already has enough capacity, head shares in's
// backing array; otherwise head is a fresh copy.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
470
471
472
// encrypt protects payload and appends it to record, which must already
// hold a record header (recordHeaderLen bytes). The appended data includes
// any explicit nonce, MAC and padding. It then fixes up the header's length
// field, increments the sequence number, and returns the extended record.
// With no cipher installed, payload is appended as-is and the sequence
// number is left unchanged. rand supplies explicit CBC IVs.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// Short explicit AEAD nonces (e.g. 8 bytes for TLS 1.2 AES-GCM)
			// are too small to be safely random, so the sequence number is
			// used instead. CBC IVs must be unpredictable (RFC 5246,
			// Appendix F.3), so they are read from rand.
			copy(explicitNonce, hc.seq[:])
		} else {
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC over header and plaintext, then encrypt payload || MAC.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Encrypt the real content type as the last plaintext byte and
			// present the record as application_data on the wire.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so its length must be final
			// (including the AEAD overhead) before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Pre-1.3 additional data: seq_num || record header.
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		// Plaintext is payload || MAC || padding. Every padding byte
		// (including the final length byte) holds paddingLen-1, and at least
		// one byte of padding is always added.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update the length to include nonce, MAC and any block padding.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
557
558
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
	// Msg describes what was wrong with the header.
	Msg string
	// RecordHeader holds the (up to) five bytes of the offending header.
	RecordHeader [5]byte
	// Conn is the underlying net.Conn when the client sent an initial
	// handshake that did not look like TLS; otherwise it is nil.
	Conn net.Conn
}

// Error formats the header error with the package's "tls: " prefix.
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
573
574 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
575 err.Msg = msg
576 err.Conn = conn
577 copy(err.RecordHeader[:], c.rawInput.Bytes())
578 return err
579 }
580
581 func (c *Conn) readRecord() error {
582 return c.readRecordOrCCS(false)
583 }
584
585 func (c *Conn) readChangeCipherSpec() error {
586 return c.readRecordOrCCS(true)
587 }
588
589
590
591
592
593
594
595
596
597
598
599
600
// readRecordOrCCS reads one record from the underlying connection, decrypts
// it, and dispatches on its type:
//   - application data is made available in c.input;
//   - handshake bytes are appended to c.hand;
//   - warning alerts and ignorable records trigger a bounded retry;
//   - anything unexpected sends an alert and latches an error in c.in.
//
// If expectChangeCipherSpec is true, the record must be a ChangeCipherSpec
// (before TLS 1.3), which activates the pending read cipher. Callers in this
// file hold c.in's lock (e.g. Read, handshakeContext).
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.handshakeComplete()

	// This function modifies c.rawInput, which backs c.input's memory, so it
	// must not run while application data is still pending.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	// Read the record header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// EOF exactly at a record boundary is reported as a clean io.EOF even
		// without a close_notify alert; EOF mid-header stays unexpected.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Errors that report Temporary() are returned without being latched,
		// so the caller may retry.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No valid TLS record has type 0x80, but an SSLv2-format ClientHello
	// begins with a length whose high bit is set, so 0x80 here strongly
	// suggests an SSLv2 client.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	n := int(hdr[3])<<8 | int(hdr[4])
	// Once the version is known, every record must carry it, except in
	// TLS 1.3 where the record-layer version field is not checked.
	if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: be extra suspicious that this may not be TLS at all,
		// and bail out before reading a full body. Only alerts and handshakes
		// are accepted, and a version >= 0x1000 is implausible.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	// && binds tighter than ||: (TLS 1.3 && n > maxCiphertextTLS13) || n > maxCiphertext.
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the record body.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process the message. record aliases c.rawInput and is decrypted in place.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data must always be protected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// This is a state-advancing message: reset the retry count.
		c.retryCount = 0
	}

	// In TLS 1.3, handshake messages must not be interleaved with other
	// record types.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// In TLS 1.3 every alert other than close_notify is fatal.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the warning and read the next record (bounded).
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// Handshake messages are not allowed to fragment across the CCS.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, change_cipher_spec records are ignored (RFC 8446,
		// Appendix D.4). A CCS received before the version is set is not
		// ignored here.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Empty application-data records (which some peers send, e.g. to
		// randomize the CBC IV) are skipped, up to a limit.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data is owned by c.rawInput (see the Next call above), which avoids
		// copying. This is safe because c.rawInput is not read from or written
		// to until c.input is drained (checked at the top of this function).
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
761
762
763
764 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
765 c.retryCount++
766 if c.retryCount > maxUselessRecords {
767 c.sendAlert(alertUnexpectedMessage)
768 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
769 }
770 return c.readRecordOrCCS(expectChangeCipherSpec)
771 }
772
773
774
775
// atLeastReader reads from R, stopping with io.EOF once at least N bytes
// have been read in total. If R ends before N bytes, it reports
// io.ErrUnexpectedEOF. It is meant to be used with bytes.Buffer.ReadFrom.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read reads from R and decrements N by the bytes read. It converts the
// result into io.EOF once the target is met, and into io.ErrUnexpectedEOF
// if R runs dry first.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	n, err := r.R.Read(p)
	r.N -= int64(n)
	switch {
	case r.N > 0 && err == io.EOF:
		return n, io.ErrUnexpectedEOF
	case r.N <= 0 && err == nil:
		return n, io.EOF
	default:
		return n, err
	}
}
795
796
797
798 func (c *Conn) readFromUntil(r io.Reader, n int) error {
799 if c.rawInput.Len() >= n {
800 return nil
801 }
802 needs := n - c.rawInput.Len()
803
804
805
806 c.rawInput.Grow(needs + bytes.MinRead)
807 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
808 return err
809 }
810
811
812 func (c *Conn) sendAlertLocked(err alert) error {
813 switch err {
814 case alertNoRenegotiation, alertCloseNotify:
815 c.tmp[0] = alertLevelWarning
816 default:
817 c.tmp[0] = alertLevelError
818 }
819 c.tmp[1] = byte(err)
820
821 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
822 if err == alertCloseNotify {
823
824 return writeErr
825 }
826
827 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
828 }
829
830
831 func (c *Conn) sendAlert(err alert) error {
832 c.out.Lock()
833 defer c.out.Unlock()
834 return c.sendAlertLocked(err)
835 }
836
const (
	// tcpMSSEstimate is an estimate of the TCP maximum segment size.
	// maxPayloadSizeForWrite subtracts TLS overhead from it to size early
	// application-data records so that each fits in about one segment.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes sent after which
	// maxPayloadSizeForWrite stops limiting records and always returns
	// maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
// maxPayloadSizeForWrite returns the maximum payload size for the next
// record of type typ.
//
// Non-application-data records, or a config with DynamicRecordSizingDisabled,
// always get maxPlaintext. Otherwise, until recordSizeBoostThreshold bytes
// have been sent, records start at about one TCP segment (tcpMSSEstimate
// minus TLS overhead) and grow linearly with each call. Past that threshold,
// or after 1000 records, they get maxPlaintext. NOTE(review): the rationale is
// presumably to reduce latency for the first bytes of a connection.
//
// Calls that reach the packet-growth logic increment c.packetsSent.
func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
	if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
		return maxPlaintext
	}

	if c.bytesSent >= recordSizeBoostThreshold {
		return maxPlaintext
	}

	// Subtract TLS overheads to get the maximum payload size.
	payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
	if c.out.cipher != nil {
		switch ciph := c.out.cipher.(type) {
		case cipher.Stream:
			payloadBytes -= c.out.mac.Size()
		case cipher.AEAD:
			payloadBytes -= ciph.Overhead()
		case cbcMode:
			blockSize := ciph.BlockSize()
			// The payload must fit in a whole number of blocks with room for
			// at least one padding byte (the mask assumes blockSize is a
			// power of two).
			payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
			// The MAC is appended before padding, so it reduces the payload
			// size directly.
			payloadBytes -= c.out.mac.Size()
		default:
			panic("unknown cipher type")
		}
	}
	if c.vers == VersionTLS13 {
		payloadBytes-- // encrypted ContentType byte
	}

	// Allow packet growth in arithmetic progression up to the maximum.
	pkt := c.packetsSent
	c.packetsSent++
	if pkt > 1000 {
		return maxPlaintext // avoid overflow in the multiply below
	}

	n := payloadBytes * int(pkt+1)
	if n > maxPlaintext {
		n = maxPlaintext
	}
	return n
}
913
914 func (c *Conn) write(data []byte) (int, error) {
915 if c.buffering {
916 c.sendBuf = append(c.sendBuf, data...)
917 return len(data), nil
918 }
919
920 n, err := c.conn.Write(data)
921 c.bytesSent += int64(n)
922 return n, err
923 }
924
925 func (c *Conn) flush() (int, error) {
926 if len(c.sendBuf) == 0 {
927 return 0, nil
928 }
929
930 n, err := c.conn.Write(c.sendBuf)
931 c.bytesSent += int64(n)
932 c.sendBuf = nil
933 c.buffering = false
934 return n, err
935 }
936
937
// outBufPool recycles the record-assembly buffers used by writeRecordLocked.
// It stores *[]byte so the slice header itself is pooled; a new entry points
// to a nil slice.
var outBufPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
943
944
945
// writeRecordLocked writes data as one or more records of type typ, each
// limited by maxPayloadSizeForWrite, and returns how many bytes of data were
// written. After a ChangeCipherSpec (before TLS 1.3) it activates the
// pending write cipher. c.out must be held.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// Store the (possibly grown) buffer back through the pointer obtained
		// from Get rather than calling Put(&outBuf). NOTE(review): presumably
		// because taking &outBuf would make this local slice header escape to
		// the heap and allocate on every call.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// Before a version is negotiated, use TLS 1.0 in the record
			// header: some servers reject a higher record version on the
			// initial ClientHello.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 froze the record-layer version at 1.2 (RFC 8446,
			// Section 5.1).
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
1003
1004
1005
1006 func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
1007 c.out.Lock()
1008 defer c.out.Unlock()
1009
1010 return c.writeRecordLocked(typ, data)
1011 }
1012
1013
1014
1015 func (c *Conn) readHandshake() (any, error) {
1016 for c.hand.Len() < 4 {
1017 if err := c.readRecord(); err != nil {
1018 return nil, err
1019 }
1020 }
1021
1022 data := c.hand.Bytes()
1023 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1024 if n > maxHandshake {
1025 c.sendAlertLocked(alertInternalError)
1026 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1027 }
1028 for c.hand.Len() < 4+n {
1029 if err := c.readRecord(); err != nil {
1030 return nil, err
1031 }
1032 }
1033 data = c.hand.Next(4 + n)
1034 var m handshakeMessage
1035 switch data[0] {
1036 case typeHelloRequest:
1037 m = new(helloRequestMsg)
1038 case typeClientHello:
1039 m = new(clientHelloMsg)
1040 case typeServerHello:
1041 m = new(serverHelloMsg)
1042 case typeNewSessionTicket:
1043 if c.vers == VersionTLS13 {
1044 m = new(newSessionTicketMsgTLS13)
1045 } else {
1046 m = new(newSessionTicketMsg)
1047 }
1048 case typeCertificate:
1049 if c.vers == VersionTLS13 {
1050 m = new(certificateMsgTLS13)
1051 } else {
1052 m = new(certificateMsg)
1053 }
1054 case typeCertificateRequest:
1055 if c.vers == VersionTLS13 {
1056 m = new(certificateRequestMsgTLS13)
1057 } else {
1058 m = &certificateRequestMsg{
1059 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1060 }
1061 }
1062 case typeCertificateStatus:
1063 m = new(certificateStatusMsg)
1064 case typeServerKeyExchange:
1065 m = new(serverKeyExchangeMsg)
1066 case typeServerHelloDone:
1067 m = new(serverHelloDoneMsg)
1068 case typeClientKeyExchange:
1069 m = new(clientKeyExchangeMsg)
1070 case typeCertificateVerify:
1071 m = &certificateVerifyMsg{
1072 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1073 }
1074 case typeFinished:
1075 m = new(finishedMsg)
1076 case typeEncryptedExtensions:
1077 m = new(encryptedExtensionsMsg)
1078 case typeEndOfEarlyData:
1079 m = new(endOfEarlyDataMsg)
1080 case typeKeyUpdate:
1081 m = new(keyUpdateMsg)
1082 default:
1083 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1084 }
1085
1086
1087
1088
1089 data = append([]byte(nil), data...)
1090
1091 if !m.unmarshal(data) {
1092 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1093 }
1094 return m, nil
1095 }
1096
// errShutdown is returned by Write after close_notify has been sent.
var errShutdown = errors.New("tls: protocol is shutdown")
1100
1101
1102
1103
1104
1105
1106
// Write writes application data to the connection, first completing the
// handshake if necessary. It returns net.ErrClosed once Close has started,
// and errShutdown after close_notify has been sent. Because Write may run
// the handshake, set deadlines for both reading and writing to bound how
// long it can block before the handshake completes.
func (c *Conn) Write(b []byte) (int, error) {
	// Interlock with Close: bit 0 of activeCall is set by Close, and each
	// in-flight Write adds 2. Refuse to start once Close has begun.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			break
		}
	}
	defer atomic.AddInt32(&c.activeCall, -2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// TLS 1.0 with a CBC (block-mode) cipher is susceptible to a
	// chosen-plaintext attack because record IVs are predictable (BEAST).
	// Sending the first byte in its own record (1/n-1 record splitting)
	// effectively randomizes the IV of the following record.
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	// setErrorLocked(nil) leaves c.out.err nil and returns nil, so a
	// successful write reports no error.
	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1162
1163
// handleRenegotiation processes a post-handshake message before TLS 1.3,
// which must be a HelloRequest. Servers always refuse renegotiation. Clients
// refuse or accept it according to config.Renegotiation; if accepted, they
// run a fresh client handshake under handshakeMutex.
func (c *Conn) handleRenegotiation() error {
	if c.vers == VersionTLS13 {
		return errors.New("tls: internal error: unexpected renegotiation")
	}

	msg, err := c.readHandshake()
	if err != nil {
		return err
	}

	helloReq, ok := msg.(*helloRequestMsg)
	if !ok {
		c.sendAlert(alertUnexpectedMessage)
		return unexpectedMessageError(helloReq, msg)
	}

	if !c.isClient {
		return c.sendAlert(alertNoRenegotiation)
	}

	switch c.config.Renegotiation {
	case RenegotiateNever:
		return c.sendAlert(alertNoRenegotiation)
	case RenegotiateOnceAsClient:
		// handshakes counts completed handshakes, so > 1 means a
		// renegotiation has already happened.
		if c.handshakes > 1 {
			return c.sendAlert(alertNoRenegotiation)
		}
	case RenegotiateFreelyAsClient:
		// Ok.
	default:
		c.sendAlert(alertInternalError)
		return errors.New("tls: unknown Renegotiation value")
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Mark the handshake as incomplete while the new one runs; the handshake
	// result is cached in handshakeErr like in handshakeContext.
	atomic.StoreUint32(&c.handshakeStatus, 0)
	if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
		c.handshakes++
	}
	return c.handshakeErr
}
1207
1208
1209
1210 func (c *Conn) handlePostHandshakeMessage() error {
1211 if c.vers != VersionTLS13 {
1212 return c.handleRenegotiation()
1213 }
1214
1215 msg, err := c.readHandshake()
1216 if err != nil {
1217 return err
1218 }
1219
1220 c.retryCount++
1221 if c.retryCount > maxUselessRecords {
1222 c.sendAlert(alertUnexpectedMessage)
1223 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1224 }
1225
1226 switch msg := msg.(type) {
1227 case *newSessionTicketMsgTLS13:
1228 return c.handleNewSessionTicket(msg)
1229 case *keyUpdateMsg:
1230 return c.handleKeyUpdate(msg)
1231 default:
1232 c.sendAlert(alertUnexpectedMessage)
1233 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1234 }
1235 }
1236
1237 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1238 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1239 if cipherSuite == nil {
1240 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1241 }
1242
1243 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1244 c.in.setTrafficSecret(cipherSuite, newSecret)
1245
1246 if keyUpdate.updateRequested {
1247 c.out.Lock()
1248 defer c.out.Unlock()
1249
1250 msg := &keyUpdateMsg{}
1251 _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
1252 if err != nil {
1253
1254 c.out.setErrorLocked(err)
1255 return nil
1256 }
1257
1258 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1259 c.out.setTrafficSecret(cipherSuite, newSecret)
1260 }
1261
1262 return nil
1263 }
1264
1265
1266
1267
1268
1269
1270
// Read reads application data from the connection, first completing the
// handshake if necessary. Handshake messages received afterwards are
// processed transparently: TLS 1.3 tickets and key updates, or
// renegotiation requests in earlier versions.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Checked after Handshake, in case callers use Read(nil) only for
		// its side effect of completing the handshake.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	n, _ := c.input.Read(b)

	// If this record is now drained and the next buffered record is an alert
	// (such as close_notify), process it now. The caller then gets (n, EOF)
	// in this call instead of discovering the closure only on its next Read.
	// That matters, e.g., to HTTP clients deciding whether to reuse the
	// connection.
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err // io.EOF on close_notify
		}
	}

	return n, nil
}
1313
1314
// Close closes the connection. If the handshake has completed and no Write
// is in flight, it first sends a close_notify alert. An error sending that
// alert is returned (wrapped) only if closing the underlying connection
// succeeds. A second Close returns net.ErrClosed.
func (c *Conn) Close() error {
	// Interlock with Write: set bit 0 of activeCall. The loaded value x
	// tells us whether any Write was in flight (each adds 2).
	var x int32
	for {
		x = atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return net.ErrClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in flight; Write and Close should not be used
		// concurrently. Treat this Close as a request to break that Write and
		// release resources: close the transport directly and skip
		// close_notify. Sending it would need the c.out lock, which the
		// in-flight Write holds, so it could block.
		return c.conn.Close()
	}

	var alertErr error
	if c.handshakeComplete() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1349
// errEarlyCloseWrite is returned by CloseWrite before the handshake has completed.
var errEarlyCloseWrite error = errors.New("tls: CloseWrite called before handshake complete")
1351
1352
1353
1354
1355 func (c *Conn) CloseWrite() error {
1356 if !c.handshakeComplete() {
1357 return errEarlyCloseWrite
1358 }
1359
1360 return c.closeNotify()
1361 }
1362
1363 func (c *Conn) closeNotify() error {
1364 c.out.Lock()
1365 defer c.out.Unlock()
1366
1367 if !c.closeNotifySent {
1368
1369 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1370 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1371 c.closeNotifySent = true
1372
1373 c.SetWriteDeadline(time.Now())
1374 }
1375 return c.closeNotifyErr
1376 }
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386 func (c *Conn) Handshake() error {
1387 return c.HandshakeContext(context.Background())
1388 }
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400 func (c *Conn) HandshakeContext(ctx context.Context) error {
1401
1402
1403 return c.handshakeContext(ctx)
1404 }
1405
// handshakeContext runs the handshake unless it has already completed or a
// previous attempt failed; both outcomes are cached. If ctx is done before
// the handshake finishes, it closes the underlying connection to abort it,
// and the context's error is returned instead.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Lock-free fast path for the common case of an already completed
	// handshake; avoids the context setup and mutex on most Read/Write calls.
	if c.handshakeComplete() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// Deferred before the interrupter's defer below, so it runs after it
	// (defers run last-in, first-out). The interrupter has therefore already
	// been told the handshake finished before handshakeCtx is cancelled here,
	// and this cancellation never closes the connection.
	defer cancel()

	// If ctx can be cancelled (a nil Done channel means it cannot), start an
	// "interrupter" goroutine. It closes the underlying connection if the
	// context is done before this function returns, which unblocks any I/O
	// the handshake is stuck in.
	if ctx.Done() != nil {
		done := make(chan struct{})
		interruptRes := make(chan error, 1)
		defer func() {
			close(done)
			// Wait for the goroutine to exit. If it closed the connection,
			// report the context error to the caller.
			if ctxErr := <-interruptRes; ctxErr != nil {
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// Close the connection, discarding the error.
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Re-check under the mutex: another goroutine may have run the handshake.
	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.handshakeComplete() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// On failure, try to flush anything left buffered (presumably an
		// alert queued while c.buffering was set).
		c.flush()
	}

	// Sanity checks: handshakeFn must either succeed and mark the handshake
	// complete, or fail and leave it incomplete.
	if c.handshakeErr == nil && !c.handshakeComplete() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.handshakeComplete() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	return c.handshakeErr
}
1478
1479
1480 func (c *Conn) ConnectionState() ConnectionState {
1481 c.handshakeMutex.Lock()
1482 defer c.handshakeMutex.Unlock()
1483 return c.connectionStateLocked()
1484 }
1485
1486 func (c *Conn) connectionStateLocked() ConnectionState {
1487 var state ConnectionState
1488 state.HandshakeComplete = c.handshakeComplete()
1489 state.Version = c.vers
1490 state.NegotiatedProtocol = c.clientProtocol
1491 state.DidResume = c.didResume
1492 state.NegotiatedProtocolIsMutual = true
1493 state.ServerName = c.serverName
1494 state.CipherSuite = c.cipherSuite
1495 state.PeerCertificates = c.peerCertificates
1496 state.VerifiedChains = c.verifiedChains
1497 state.SignedCertificateTimestamps = c.scts
1498 state.OCSPResponse = c.ocspResponse
1499 if !c.didResume && c.vers != VersionTLS13 {
1500 if c.clientFinishedIsFirst {
1501 state.TLSUnique = c.clientFinished[:]
1502 } else {
1503 state.TLSUnique = c.serverFinished[:]
1504 }
1505 }
1506 if c.config.Renegotiation != RenegotiateNever {
1507 state.ekm = noExportedKeyingMaterial
1508 } else {
1509 state.ekm = c.ekm
1510 }
1511 return state
1512 }
1513
1514
1515
1516 func (c *Conn) OCSPResponse() []byte {
1517 c.handshakeMutex.Lock()
1518 defer c.handshakeMutex.Unlock()
1519
1520 return c.ocspResponse
1521 }
1522
1523
1524
1525
1526 func (c *Conn) VerifyHostname(host string) error {
1527 c.handshakeMutex.Lock()
1528 defer c.handshakeMutex.Unlock()
1529 if !c.isClient {
1530 return errors.New("tls: VerifyHostname called on TLS server connection")
1531 }
1532 if !c.handshakeComplete() {
1533 return errors.New("tls: handshake has not yet been performed")
1534 }
1535 if len(c.verifiedChains) == 0 {
1536 return errors.New("tls: handshake did not verify certificate chain")
1537 }
1538 return c.peerCertificates[0].VerifyHostname(host)
1539 }
1540
1541 func (c *Conn) handshakeComplete() bool {
1542 return atomic.LoadUint32(&c.handshakeStatus) == 1
1543 }
1544
View as plain text