1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "io"
19 "net"
20 "sync"
21 "sync/atomic"
22 "time"
23 )
24
25
26
// Conn is a TLS connection layered on top of an underlying net.Conn.
// It carries the handshake results, the per-direction record protection
// state, and the buffers used by the record layer in this file.
type Conn struct {
	// Connection identity and entry points.

	// conn is the underlying transport; all raw reads and writes go through it.
	conn net.Conn
	// isClient reports whether this is the client side of the connection.
	isClient bool
	// handshakeFn is the handshake implementation run by handshakeContext.
	handshakeFn func(context.Context) error
	// quic is non-nil when the connection is driven by a QUIC transport; the
	// record-layer read path refuses to run in that case.
	quic *quicState

	// isHandshakeComplete is true once a handshake has finished successfully.
	// It is atomic so Read/Write can check it without taking handshakeMutex.
	isHandshakeComplete atomic.Bool

	// handshakeMutex is held while a handshake runs and while handshake
	// results are read (ConnectionState, OCSPResponse, VerifyHostname).
	handshakeMutex sync.Mutex
	// handshakeErr is the error of the last handshake, if any; once set,
	// handshakeContext returns it without running the handshake again.
	handshakeErr error
	// vers is the TLS version in use; zero until one has been chosen.
	vers uint16
	// haveVers reports whether vers is valid; until then the read path only
	// sanity-checks the first record instead of enforcing the version.
	haveVers bool
	// config is the configuration this connection was created with.
	config *Config

	// handshakes counts successfully completed handshakes (initial handshake
	// plus renegotiations).
	handshakes int
	// extMasterSecret is consulted when deciding whether TLSUnique may be
	// exposed for a resumed session.
	extMasterSecret bool
	// didResume, cipherSuite, ocspResponse, scts and peerCertificates are
	// handshake results surfaced through ConnectionState.
	didResume        bool
	cipherSuite      uint16
	ocspResponse     []byte
	scts             [][]byte
	peerCertificates []*x509.Certificate

	// activeCertHandles is not used in this part of the file.
	// NOTE(review): presumably keeps cached certificate entries alive for
	// peerCertificates — confirm against the handshake code.
	activeCertHandles []*activeCert

	// verifiedChains holds the verified certificate chains; VerifyHostname
	// requires it to be non-empty.
	verifiedChains [][]*x509.Certificate

	// serverName is reported as ConnectionState.ServerName.
	serverName string

	// secureRenegotiation is not used in this part of the file.
	// NOTE(review): presumably records that the peer supports the secure
	// renegotiation extension — confirm.
	secureRenegotiation bool

	// ekm produces exported keying material; it is only exposed through
	// ConnectionState when renegotiation is disabled in config.
	ekm func(label string, context []byte, length int) ([]byte, error)

	// resumptionSecret is not used in this part of the file.
	// NOTE(review): presumably the TLS 1.3 resumption secret — confirm.
	resumptionSecret []byte

	// ticketKeys is not used in this part of the file.
	// NOTE(review): presumably the session ticket keys for this connection —
	// confirm.
	ticketKeys []ticketKey

	// clientFinishedIsFirst selects which Finished value (client or server)
	// is reported as TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr is the result of sending the close_notify alert.
	closeNotifyErr error

	// closeNotifySent is true once close_notify has been attempted; later
	// writes fail with errShutdown.
	closeNotifySent bool

	// clientFinished and serverFinished hold the Finished values used for
	// the TLSUnique channel binding.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is reported as ConnectionState.NegotiatedProtocol.
	clientProtocol string

	// Record layer state.

	// in and out hold the read-side and write-side protection state; each
	// embeds its own mutex.
	in, out halfConn
	// rawInput buffers bytes read from conn that have not yet been consumed
	// as complete records.
	rawInput bytes.Buffer
	// input exposes the application data of the most recent record to Read.
	// Its bytes alias rawInput's memory, so rawInput is not touched until
	// input has been drained.
	input bytes.Reader
	// hand accumulates handshake data awaiting parsing.
	hand bytes.Buffer
	// buffering makes write append to sendBuf instead of writing to conn.
	buffering bool
	// sendBuf holds buffered outgoing records until flush.
	sendBuf []byte

	// bytesSent and packetsSent feed the dynamic record sizing decision in
	// maxPayloadSizeForWrite.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive received records that did not advance
	// the connection state; it is bounded by maxUselessRecords.
	retryCount int

	// activeCall interlocks Close and Write: the low bit is set once Close
	// has been called, and each in-flight Write adds 2.
	activeCall atomic.Int32

	// tmp is scratch space; sendAlertLocked uses its first two bytes.
	tmp [16]byte
}
123
124
125
126
127
128
// LocalAddr returns the local network address of the underlying connection.
func (c *Conn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}
132
133
// RemoteAddr returns the remote network address of the underlying connection.
func (c *Conn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}
137
138
139
140
// SetDeadline forwards t to SetDeadline on the underlying connection and
// returns its result.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}
144
145
146
// SetReadDeadline forwards t to SetReadDeadline on the underlying connection
// and returns its result.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}
150
151
152
153
// SetWriteDeadline forwards t to SetWriteDeadline on the underlying
// connection and returns its result. closeNotify relies on it to bound the
// time spent sending the close_notify alert and to make later writes fail.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}
157
158
159
160
// NetConn returns the underlying connection that c wraps. Reading from or
// writing to it directly bypasses the TLS record layer implemented here.
func (c *Conn) NetConn() net.Conn {
	return c.conn
}
164
165
166
// halfConn holds the record protection state for one direction (read or
// write) of a connection. The embedded mutex guards the direction it
// belongs to.
type halfConn struct {
	sync.Mutex

	// err is the sticky error for this direction, set via setErrorLocked.
	err error
	// version is the protocol version used to pick record formats.
	version uint16
	// cipher is the active record cipher: a cipher.Stream, an aead, or a
	// cbcMode. nil means records are sent and received unprotected.
	cipher any
	// mac is the record MAC used with stream and CBC ciphers; nil otherwise.
	mac hash.Hash
	// seq is the 64-bit big-endian record sequence number.
	seq [8]byte

	// scratchBuf is reusable scratch space for MAC and AEAD additional-data
	// construction.
	scratchBuf [13]byte

	// nextCipher and nextMac are staged by prepareCipherSpec and become
	// active in changeCipherSpec.
	nextCipher any
	nextMac    hash.Hash

	// level and trafficSecret are recorded by setTrafficSecret; the secret is
	// the input for deriving the next one on key update.
	level         QUICEncryptionLevel
	trafficSecret []byte
}
184
// permanentError wraps a net.Error and overrides Temporary so that the
// wrapped error is always reported as non-temporary, while preserving its
// message and timeout status.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message.
func (e *permanentError) Error() string {
	return e.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	return e.err
}

// Timeout reports whether the wrapped error is a timeout.
func (e *permanentError) Timeout() bool {
	return e.err.Timeout()
}

// Temporary always reports false: once recorded, the error is permanent.
func (e *permanentError) Temporary() bool {
	return false
}
193
194 func (hc *halfConn) setErrorLocked(err error) error {
195 if e, ok := err.(net.Error); ok {
196 hc.err = &permanentError{err: e}
197 } else {
198 hc.err = err
199 }
200 return hc.err
201 }
202
203
204
205 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
206 hc.version = version
207 hc.nextCipher = cipher
208 hc.nextMac = mac
209 }
210
211
212
213 func (hc *halfConn) changeCipherSpec() error {
214 if hc.nextCipher == nil || hc.version == VersionTLS13 {
215 return alertInternalError
216 }
217 hc.cipher = hc.nextCipher
218 hc.mac = hc.nextMac
219 hc.nextCipher = nil
220 hc.nextMac = nil
221 for i := range hc.seq {
222 hc.seq[i] = 0
223 }
224 return nil
225 }
226
227 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
228 hc.trafficSecret = secret
229 hc.level = level
230 key, iv := suite.trafficKey(secret)
231 hc.cipher = suite.aead(key, iv)
232 for i := range hc.seq {
233 hc.seq[i] = 0
234 }
235 }
236
237
238 func (hc *halfConn) incSeq() {
239 for i := 7; i >= 0; i-- {
240 hc.seq[i]++
241 if hc.seq[i] != 0 {
242 return
243 }
244 }
245
246
247
248
249 panic("TLS: sequence number wraparound")
250 }
251
252
253
254
255 func (hc *halfConn) explicitNonceLen() int {
256 if hc.cipher == nil {
257 return 0
258 }
259
260 switch c := hc.cipher.(type) {
261 case cipher.Stream:
262 return 0
263 case aead:
264 return c.explicitNonceLen()
265 case cbcMode:
266
267 if hc.version >= VersionTLS11 {
268 return c.BlockSize()
269 }
270 return 0
271 default:
272 panic("unknown cipher type")
273 }
274 }
275
276
277
278
// extractPadding inspects the CBC padding at the end of payload. It returns
// toRemove, the number of trailing bytes to strip (padding bytes plus the
// padding-length byte), and good, which is 255 if the padding is well formed
// and 0 otherwise. The checks are written with branch-free mask arithmetic
// on the secret bytes; only the (public) payload length drives a branch.
// An empty payload yields (0, 0).
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	// The last byte is the claimed padding length.
	paddingLen := payload[len(payload)-1]
	// t underflows (sets its high bits) when paddingLen exceeds the bytes
	// available before the length byte.
	t := uint(len(payload)-1) - uint(paddingLen)
	// good becomes all ones when t did not underflow, all zeros otherwise.
	good = byte(int32(^t) >> 31)

	// At most 255 padding bytes plus the length byte need to be examined.
	toCheck := 256
	// Clamp to the payload length, which is not secret.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is all ones while i <= paddingLen, i.e. for bytes that are
		// supposed to be padding; all zeros afterwards.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear any bit of good where a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all eight bits of good together and replicate the result across
	// the whole byte, so good ends up as either 0x00 or 0xff.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// On invalid padding, report a padding length of zero so that the caller
	// still processes a well-defined amount of data (toRemove becomes 1).
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
325
// roundUp returns a rounded up to the next multiple of b; a value that is
// already a multiple of b is returned unchanged.
func roundUp(a, b int) int {
	rem := a % b
	pad := (b - rem) % b
	return a + pad
}
329
330
// cbcMode is a cipher.BlockMode whose IV can be replaced between records.
// The record layer calls SetIV with the explicit per-record IV before
// encrypting or decrypting.
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
335
336
337
// decrypt authenticates and decrypts record, which must be a complete record
// including its header. It returns the plaintext and the record's content
// type (for TLS 1.3, the inner type recovered from the plaintext). The
// plaintext aliases record's memory, and record's length bytes may be
// rewritten. Failures are returned as alert values; the sequence number is
// advanced only on success.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3 a change_cipher_spec record is handed back as-is, without
	// decryption and without advancing the sequence number.
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// Defaults for ciphers without padding: padding is "good" and empty.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			// Decrypt in place; the MAC is verified further down.
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			// Use the explicit nonce carried in the record if there is one,
			// otherwise the sequence number.
			nonce := payload[:explicitNonceLen]
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			var additionalData []byte
			if hc.version == VersionTLS13 {
				// TLS 1.3 authenticates the record header as sent.
				additionalData = record[:recordHeaderLen]
			} else {
				// Earlier versions authenticate the sequence number, the
				// type and version bytes, and the plaintext length.
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			// Open in place, reusing payload's storage for the plaintext.
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// The payload must be whole blocks and large enough to hold the
			// explicit IV plus a MAC and at least one padding byte.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// Padding validity is not acted on here: it is folded into the
			// MAC comparison below so that a bad pad and a bad MAC are
			// handled by the same code path.
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			// Protected TLS 1.3 records always carry the application_data
			// outer type.
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			// One extra byte is allowed for the inner content type.
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Strip trailing zero padding; the last non-zero byte is the real
			// content type. A plaintext of only zeros is rejected.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		// No cipher installed: the payload is the plaintext.
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		// n is the length of the data preceding the MAC.
		n := len(payload) - macSize - paddingLen
		// Clamp a negative n to zero without branching on it.
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers a header carrying the plaintext length, so rewrite
		// the length bytes before computing it.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		// The bytes after the MAC (the padding) are passed as the final
		// argument. NOTE(review): presumably so tls10MAC can process them
		// to keep timing independent of the padding length — confirm
		// against tls10MAC.
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// The comparison yields 1 on a match and paddingGood is 255 for
		// valid padding, so the AND is 1 only when both checks pass.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
461
462
463
464
// sliceForAppend extends in by n bytes. head is the full extended slice and
// tail is the newly added n-byte region at its end. The existing backing
// array is reused when it has enough capacity; otherwise a new slice is
// allocated and the contents of in are copied into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	oldLen := len(in)
	newLen := oldLen + n
	if newLen <= cap(in) {
		head = in[:newLen]
	} else {
		head = make([]byte, newLen)
		copy(head, in)
	}
	return head, head[oldLen:]
}
475
476
477
// encrypt protects payload and appends it to record, which must already
// contain the record header, returning the extended record. The header's
// length field is updated to the final protected length and the sequence
// number is advanced. rand supplies random explicit IVs when one is needed.
// With no cipher installed the payload is appended unmodified and the
// sequence number is left untouched.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		// Reserve room for the explicit nonce right after the header.
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// For non-CBC ciphers with a short explicit nonce, the sequence
			// number is used as the nonce: it is unique per record under a
			// given key and needs no randomness.
			copy(explicitNonce, hc.seq[:])
		} else {
			// Otherwise (CBC IVs, or nonces of 16 bytes or more) draw the
			// value from rand.
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC-then-encrypt: the MAC is computed over the plaintext, then
		// payload and MAC are encrypted into the record.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Move the real content type to the end of the plaintext so it
			// is encrypted, and expose application_data as the outer type.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so its length field must
			// already hold the final ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			// Seal in place: output starts right after the header.
			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Additional data is the sequence number followed by the header
			// as supplied by the caller (carrying the plaintext length).
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// Always between 1 and blockSize bytes of padding, so the result is
		// a whole number of blocks.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		// Every padding byte, including the final length byte, holds
		// paddingLen-1.
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Rewrite the length field to cover the nonce, MAC/tag and padding.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
562
563
// RecordHeaderError is returned when a received TLS record header is
// invalid.
type RecordHeaderError struct {
	// Msg describes the problem; Error prefixes it with "tls: ".
	Msg string

	// RecordHeader holds up to the first five bytes of buffered input at the
	// time the error was created, i.e. the offending record header.
	RecordHeader [5]byte

	// Conn is the underlying connection. In this file it is only set when
	// the very first record does not look like a TLS handshake; in all other
	// cases it is nil.
	// NOTE(review): presumably provided so a server can reply on the raw
	// connection to a non-TLS client — confirm intended use.
	Conn net.Conn
}

// Error returns the message prefixed with "tls: ".
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
578
579 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
580 err.Msg = msg
581 err.Conn = conn
582 copy(err.RecordHeader[:], c.rawInput.Bytes())
583 return err
584 }
585
// readRecord reads and processes the next record, treating a
// change_cipher_spec record as unexpected (outside TLS 1.3, where it is
// ignored). See readRecordOrCCS.
func (c *Conn) readRecord() error {
	return c.readRecordOrCCS(false)
}
589
// readChangeCipherSpec reads the next record, which is expected to be a
// change_cipher_spec record; handshake and application data records are
// rejected. See readRecordOrCCS.
func (c *Conn) readChangeCipherSpec() error {
	return c.readRecordOrCCS(true)
}
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
// readRecordOrCCS reads one record from the connection, decrypts it and
// dispatches on its content type:
//
//   - handshake data is appended to c.hand;
//   - application data becomes available through c.input;
//   - a change_cipher_spec record switches the read cipher when
//     expectChangeCipherSpec is true (and is ignored under TLS 1.3);
//   - alerts are turned into errors, except warning-level alerts before
//     TLS 1.3, which are skipped.
//
// Records that are skipped cause another read via retryReadRecord, which
// bounds the number of consecutive skipped records. Errors other than
// temporary network errors are recorded in c.in.err and returned by all
// subsequent calls. It uses the *Locked helpers on c.in, so the caller is
// expected to hold c.in; c.input must be fully drained before calling.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	// Sampled once so the whole call uses a consistent view.
	handshakeComplete := c.isHandshakeComplete.Load()

	// c.input aliases c.rawInput's memory, which is about to be modified, so
	// there must be no unread application data left.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	if c.quic != nil {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
	}

	// Read the record header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An EOF exactly at a record boundary (nothing buffered) is reported
		// as a plain io.EOF rather than an unexpected one.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned but not made sticky.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// A first byte of 0x80 before the handshake has completed is treated as
	// an SSLv2-style hello, which is not supported.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	expectedVers := c.vers
	if expectedVers == VersionTLS13 {
		// TLS 1.3 records carry the TLS 1.2 version number in the header.
		expectedVers = VersionTLS12
	}
	n := int(hdr[3])<<8 | int(hdr[4])
	if c.haveVers && vers != expectedVers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record on the connection: bail out before reading the body
		// if this does not look like TLS at all (wrong content type, or an
		// implausibly large version number). No alert is sent, and the
		// underlying connection is attached to the error.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	// Enforce the ciphertext size limit (a separate limit applies to TLS 1.3).
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the rest of the record.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Consume the record from the raw buffer and decrypt it in place. Note
	// that typ is replaced by the type reported by decrypt.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data is never accepted unencrypted.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// A non-empty handshake or application data record counts as
		// progress: reset the skipped-record counter.
		c.retryCount = 0
	}

	// Under TLS 1.3, a partially received handshake message must not be
	// interrupted by a record of another type.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if c.quic != nil {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// An alert is exactly two bytes: level and description.
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// close_notify is surfaced as a (sticky) io.EOF.
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// Under TLS 1.3 every other alert is fatal regardless of its level.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Ignore the warning and read the next record.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		// The message body is a single byte with value 1.
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// A handshake message must not be split across the cipher change.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Under TLS 1.3 these records carry no meaning and are skipped,
		// subject to the skipped-record limit.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Activate the pending read cipher.
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Empty application data records are tolerated but only in limited
		// numbers; skip this one and read the next record.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data still points into c.rawInput's memory (no copy is made);
		// this is why c.input must be drained before the next record read.
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
781
782
783
784 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
785 c.retryCount++
786 if c.retryCount > maxUselessRecords {
787 c.sendAlert(alertUnexpectedMessage)
788 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
789 }
790 return c.readRecordOrCCS(expectChangeCipherSpec)
791 }
792
793
794
795
// atLeastReader reads from R and reports io.EOF as soon as at least N bytes
// have been read in total, so that a consumer reading "until EOF" stops
// there. If R ends before N bytes were delivered, io.ErrUnexpectedEOF is
// returned instead. N is decremented as bytes are read.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read implements io.Reader with the semantics described on atLeastReader.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}

	n, err := r.R.Read(p)
	r.N -= int64(n)

	switch {
	case r.N > 0 && err == io.EOF:
		// The source ran dry before the required amount was read.
		return n, io.ErrUnexpectedEOF
	case r.N <= 0 && err == nil:
		// Requirement met: signal EOF along with the final bytes.
		return n, io.EOF
	default:
		return n, err
	}
}
815
816
817
818 func (c *Conn) readFromUntil(r io.Reader, n int) error {
819 if c.rawInput.Len() >= n {
820 return nil
821 }
822 needs := n - c.rawInput.Len()
823
824
825
826 c.rawInput.Grow(needs + bytes.MinRead)
827 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
828 return err
829 }
830
831
832 func (c *Conn) sendAlertLocked(err alert) error {
833 if c.quic != nil {
834 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
835 }
836
837 switch err {
838 case alertNoRenegotiation, alertCloseNotify:
839 c.tmp[0] = alertLevelWarning
840 default:
841 c.tmp[0] = alertLevelError
842 }
843 c.tmp[1] = byte(err)
844
845 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
846 if err == alertCloseNotify {
847
848 return writeErr
849 }
850
851 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
852 }
853
854
// sendAlert acquires c.out and sends the alert err to the peer via
// sendAlertLocked. The caller must not already hold c.out.
func (c *Conn) sendAlert(err alert) error {
	c.out.Lock()
	defer c.out.Unlock()
	return c.sendAlertLocked(err)
}
860
const (
	// tcpMSSEstimate is the byte budget from which maxPayloadSizeForWrite
	// subtracts the record overhead to size early application data records.
	// NOTE(review): as the name suggests, presumably a conservative estimate
	// of the TCP maximum segment size — confirm the derivation of 1208.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes sent on a connection
	// after which maxPayloadSizeForWrite stops limiting the record size and
	// always allows maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
892 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
893 return maxPlaintext
894 }
895
896 if c.bytesSent >= recordSizeBoostThreshold {
897 return maxPlaintext
898 }
899
900
901 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
902 if c.out.cipher != nil {
903 switch ciph := c.out.cipher.(type) {
904 case cipher.Stream:
905 payloadBytes -= c.out.mac.Size()
906 case cipher.AEAD:
907 payloadBytes -= ciph.Overhead()
908 case cbcMode:
909 blockSize := ciph.BlockSize()
910
911
912 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
913
914
915 payloadBytes -= c.out.mac.Size()
916 default:
917 panic("unknown cipher type")
918 }
919 }
920 if c.vers == VersionTLS13 {
921 payloadBytes--
922 }
923
924
925 pkt := c.packetsSent
926 c.packetsSent++
927 if pkt > 1000 {
928 return maxPlaintext
929 }
930
931 n := payloadBytes * int(pkt+1)
932 if n > maxPlaintext {
933 n = maxPlaintext
934 }
935 return n
936 }
937
938 func (c *Conn) write(data []byte) (int, error) {
939 if c.buffering {
940 c.sendBuf = append(c.sendBuf, data...)
941 return len(data), nil
942 }
943
944 n, err := c.conn.Write(data)
945 c.bytesSent += int64(n)
946 return n, err
947 }
948
949 func (c *Conn) flush() (int, error) {
950 if len(c.sendBuf) == 0 {
951 return 0, nil
952 }
953
954 n, err := c.conn.Write(c.sendBuf)
955 c.bytesSent += int64(n)
956 c.sendBuf = nil
957 c.buffering = false
958 return n, err
959 }
960
961
// outBufPool recycles the buffers in which writeRecordLocked assembles
// outgoing records. It stores *[]byte rather than []byte so that putting a
// buffer back does not allocate.
var outBufPool = sync.Pool{
	New: func() any {
		return new([]byte)
	},
}
967
968
969
970 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
971 if c.quic != nil {
972 if typ != recordTypeHandshake {
973 return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
974 }
975 c.quicWriteCryptoData(c.out.level, data)
976 if !c.buffering {
977 if _, err := c.flush(); err != nil {
978 return 0, err
979 }
980 }
981 return len(data), nil
982 }
983
984 outBufPtr := outBufPool.Get().(*[]byte)
985 outBuf := *outBufPtr
986 defer func() {
987
988
989
990
991
992 *outBufPtr = outBuf
993 outBufPool.Put(outBufPtr)
994 }()
995
996 var n int
997 for len(data) > 0 {
998 m := len(data)
999 if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
1000 m = maxPayload
1001 }
1002
1003 _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
1004 outBuf[0] = byte(typ)
1005 vers := c.vers
1006 if vers == 0 {
1007
1008
1009 vers = VersionTLS10
1010 } else if vers == VersionTLS13 {
1011
1012
1013 vers = VersionTLS12
1014 }
1015 outBuf[1] = byte(vers >> 8)
1016 outBuf[2] = byte(vers)
1017 outBuf[3] = byte(m >> 8)
1018 outBuf[4] = byte(m)
1019
1020 var err error
1021 outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
1022 if err != nil {
1023 return n, err
1024 }
1025 if _, err := c.write(outBuf); err != nil {
1026 return n, err
1027 }
1028 n += m
1029 data = data[m:]
1030 }
1031
1032 if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
1033 if err := c.out.changeCipherSpec(); err != nil {
1034 return n, c.sendAlertLocked(err.(alert))
1035 }
1036 }
1037
1038 return n, nil
1039 }
1040
1041
1042
1043
1044 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1045 c.out.Lock()
1046 defer c.out.Unlock()
1047
1048 data, err := msg.marshal()
1049 if err != nil {
1050 return 0, err
1051 }
1052 if transcript != nil {
1053 transcript.Write(data)
1054 }
1055
1056 return c.writeRecordLocked(recordTypeHandshake, data)
1057 }
1058
1059
1060
// writeChangeCipherRecord sends a change_cipher_spec record (a single byte
// with value 1) while holding c.out. Outside TLS 1.3 this also activates the
// pending write cipher, as done by writeRecordLocked.
func (c *Conn) writeChangeCipherRecord() error {
	c.out.Lock()
	defer c.out.Unlock()
	_, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
	return err
}
1067
1068
1069 func (c *Conn) readHandshakeBytes(n int) error {
1070 if c.quic != nil {
1071 return c.quicReadHandshakeBytes(n)
1072 }
1073 for c.hand.Len() < n {
1074 if err := c.readRecord(); err != nil {
1075 return err
1076 }
1077 }
1078 return nil
1079 }
1080
1081
1082
1083
1084 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1085 if err := c.readHandshakeBytes(4); err != nil {
1086 return nil, err
1087 }
1088 data := c.hand.Bytes()
1089 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1090 if n > maxHandshake {
1091 c.sendAlertLocked(alertInternalError)
1092 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1093 }
1094 if err := c.readHandshakeBytes(4 + n); err != nil {
1095 return nil, err
1096 }
1097 data = c.hand.Next(4 + n)
1098 return c.unmarshalHandshakeMessage(data, transcript)
1099 }
1100
1101 func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
1102 var m handshakeMessage
1103 switch data[0] {
1104 case typeHelloRequest:
1105 m = new(helloRequestMsg)
1106 case typeClientHello:
1107 m = new(clientHelloMsg)
1108 case typeServerHello:
1109 m = new(serverHelloMsg)
1110 case typeNewSessionTicket:
1111 if c.vers == VersionTLS13 {
1112 m = new(newSessionTicketMsgTLS13)
1113 } else {
1114 m = new(newSessionTicketMsg)
1115 }
1116 case typeCertificate:
1117 if c.vers == VersionTLS13 {
1118 m = new(certificateMsgTLS13)
1119 } else {
1120 m = new(certificateMsg)
1121 }
1122 case typeCertificateRequest:
1123 if c.vers == VersionTLS13 {
1124 m = new(certificateRequestMsgTLS13)
1125 } else {
1126 m = &certificateRequestMsg{
1127 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1128 }
1129 }
1130 case typeCertificateStatus:
1131 m = new(certificateStatusMsg)
1132 case typeServerKeyExchange:
1133 m = new(serverKeyExchangeMsg)
1134 case typeServerHelloDone:
1135 m = new(serverHelloDoneMsg)
1136 case typeClientKeyExchange:
1137 m = new(clientKeyExchangeMsg)
1138 case typeCertificateVerify:
1139 m = &certificateVerifyMsg{
1140 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1141 }
1142 case typeFinished:
1143 m = new(finishedMsg)
1144 case typeEncryptedExtensions:
1145 m = new(encryptedExtensionsMsg)
1146 case typeEndOfEarlyData:
1147 m = new(endOfEarlyDataMsg)
1148 case typeKeyUpdate:
1149 m = new(keyUpdateMsg)
1150 default:
1151 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1152 }
1153
1154
1155
1156
1157 data = append([]byte(nil), data...)
1158
1159 if !m.unmarshal(data) {
1160 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1161 }
1162
1163 if transcript != nil {
1164 transcript.Write(data)
1165 }
1166
1167 return m, nil
1168 }
1169
var (
	// errShutdown is returned by Write once a close_notify alert has been
	// sent on the connection.
	errShutdown = errors.New("tls: protocol is shutdown")
)
1173
1174
1175
1176
1177
1178
1179
// Write writes b to the connection as application data, running the
// handshake first if it has not completed. It returns the number of bytes of
// b that were written.
//
// Write fails with net.ErrClosed if Close has been called, with the sticky
// write error if an earlier write failed, and with errShutdown if a
// close_notify alert has already been sent.
func (c *Conn) Write(b []byte) (int, error) {
	// Interlock with Close: the low bit of activeCall marks the connection
	// as closed, and every in-flight Write adds 2 for its duration.
	for {
		x := c.activeCall.Load()
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x+2) {
			break
		}
	}
	defer c.activeCall.Add(-2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	// Handshake returned nil, so this indicates an internal inconsistency.
	if !c.isHandshakeComplete.Load() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// Under TLS 1.0 with a block-mode cipher, the first byte is sent in a
	// record of its own and the rest in following records.
	// NOTE(review): presumably the 1/n-1 record splitting countermeasure
	// against attacks on predictable CBC IVs — confirm.
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			// Account for the byte already sent.
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	// Any write error becomes sticky; setErrorLocked(nil) keeps it clear.
	return n + m, c.out.setErrorLocked(err)
}
1235
1236
1237 func (c *Conn) handleRenegotiation() error {
1238 if c.vers == VersionTLS13 {
1239 return errors.New("tls: internal error: unexpected renegotiation")
1240 }
1241
1242 msg, err := c.readHandshake(nil)
1243 if err != nil {
1244 return err
1245 }
1246
1247 helloReq, ok := msg.(*helloRequestMsg)
1248 if !ok {
1249 c.sendAlert(alertUnexpectedMessage)
1250 return unexpectedMessageError(helloReq, msg)
1251 }
1252
1253 if !c.isClient {
1254 return c.sendAlert(alertNoRenegotiation)
1255 }
1256
1257 switch c.config.Renegotiation {
1258 case RenegotiateNever:
1259 return c.sendAlert(alertNoRenegotiation)
1260 case RenegotiateOnceAsClient:
1261 if c.handshakes > 1 {
1262 return c.sendAlert(alertNoRenegotiation)
1263 }
1264 case RenegotiateFreelyAsClient:
1265
1266 default:
1267 c.sendAlert(alertInternalError)
1268 return errors.New("tls: unknown Renegotiation value")
1269 }
1270
1271 c.handshakeMutex.Lock()
1272 defer c.handshakeMutex.Unlock()
1273
1274 c.isHandshakeComplete.Store(false)
1275 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1276 c.handshakes++
1277 }
1278 return c.handshakeErr
1279 }
1280
1281
1282
1283 func (c *Conn) handlePostHandshakeMessage() error {
1284 if c.vers != VersionTLS13 {
1285 return c.handleRenegotiation()
1286 }
1287
1288 msg, err := c.readHandshake(nil)
1289 if err != nil {
1290 return err
1291 }
1292 c.retryCount++
1293 if c.retryCount > maxUselessRecords {
1294 c.sendAlert(alertUnexpectedMessage)
1295 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1296 }
1297
1298 switch msg := msg.(type) {
1299 case *newSessionTicketMsgTLS13:
1300 return c.handleNewSessionTicket(msg)
1301 case *keyUpdateMsg:
1302 return c.handleKeyUpdate(msg)
1303 }
1304
1305
1306
1307
1308 c.sendAlert(alertUnexpectedMessage)
1309 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1310 }
1311
1312 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1313 if c.quic != nil {
1314 c.sendAlert(alertUnexpectedMessage)
1315 return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
1316 }
1317
1318 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1319 if cipherSuite == nil {
1320 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1321 }
1322
1323 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1324 c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1325
1326 if keyUpdate.updateRequested {
1327 c.out.Lock()
1328 defer c.out.Unlock()
1329
1330 msg := &keyUpdateMsg{}
1331 msgBytes, err := msg.marshal()
1332 if err != nil {
1333 return err
1334 }
1335 _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
1336 if err != nil {
1337
1338 c.out.setErrorLocked(err)
1339 return nil
1340 }
1341
1342 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1343 c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1344 }
1345
1346 return nil
1347 }
1348
1349
1350
1351
1352
1353
1354
// Read reads application data from the connection into b, running the
// handshake first if it has not completed. It returns as soon as some data
// is available. When a close_notify alert directly follows the data just
// returned, Read returns that data together with io.EOF.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Checked after Handshake so that a zero-length Read can still be
		// used to trigger the handshake.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	// Read records until one carries application data, processing any
	// post-handshake messages that arrive in between.
	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	// Reading from a non-empty bytes.Reader cannot fail.
	n, _ := c.input.Read(b)

	// If the current record has been fully consumed and the next buffered
	// record is an alert, process it now. For a close_notify this lets the
	// caller see (n, io.EOF) immediately instead of (n, nil) followed by
	// (0, io.EOF) on the next call.
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err
		}
	}

	return n, nil
}
1397
1398
// Close closes the connection. If the handshake has completed and no Write
// is in flight, a close_notify alert is sent first. A second call returns
// net.ErrClosed.
func (c *Conn) Close() error {
	// Interlock with Write: set the low "closed" bit of activeCall and
	// learn whether any Write is in progress (x != 0).
	var x int32
	for {
		x = c.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in flight and holds c.out, so sending close_notify
		// would block behind it. Close the underlying connection right away
		// instead, which also unblocks that Write.
		return c.conn.Close()
	}

	var alertErr error
	if c.isHandshakeComplete.Load() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	// An error from closing the transport takes precedence over a failure
	// to send the alert.
	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1433
// errEarlyCloseWrite is returned by CloseWrite when the handshake has not
// completed yet.
var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1435
1436
1437
1438
1439 func (c *Conn) CloseWrite() error {
1440 if !c.isHandshakeComplete.Load() {
1441 return errEarlyCloseWrite
1442 }
1443
1444 return c.closeNotify()
1445 }
1446
1447 func (c *Conn) closeNotify() error {
1448 c.out.Lock()
1449 defer c.out.Unlock()
1450
1451 if !c.closeNotifySent {
1452
1453 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1454 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1455 c.closeNotifySent = true
1456
1457 c.SetWriteDeadline(time.Now())
1458 }
1459 return c.closeNotifyErr
1460 }
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
// Handshake runs the client or server handshake if it has not been run yet.
// It is equivalent to HandshakeContext with a background context, so it
// cannot be cancelled. Read and Write call it themselves.
func (c *Conn) Handshake() error {
	return c.HandshakeContext(context.Background())
}
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
// HandshakeContext runs the client or server handshake if it has not been
// run yet. If ctx is done before the handshake finishes, the underlying
// connection is closed and the context's error is returned.
func (c *Conn) HandshakeContext(ctx context.Context) error {
	// Delegate to the unexported implementation, which uses a named result
	// to substitute the context error on cancellation.
	return c.handshakeContext(ctx)
}
1494
// handshakeContext implements HandshakeContext. It runs c.handshakeFn at
// most once per successful handshake, serialised by handshakeMutex and with
// c.in held, and stores the outcome in c.handshakeErr. The named result ret
// lets the deferred interrupter cleanup replace the return value with the
// context's error when ctx was cancelled.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Fast path without locks or context setup: the last handshake
	// succeeded and none is in flight. This is what most Read and Write
	// calls hit.
	if c.isHandshakeComplete.Load() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// This defer is registered before the interrupter's defer below, so it
	// runs after it: the interrupter is told to stop (via done) before
	// handshakeCtx is cancelled, and therefore only reacts to cancellation
	// of the caller's ctx, not to this cleanup.
	defer cancel()

	if c.quic != nil {
		// Hand the cancellation hooks to the QUIC state; no interrupter
		// goroutine is started on QUIC connections.
		c.quic.cancelc = handshakeCtx.Done()
		c.quic.cancel = cancel
	} else if ctx.Done() != nil {
		// ctx can be cancelled (a nil Done channel means it never will be).
		// Start an interrupter goroutine that closes the connection if the
		// context is done before this function returns.
		done := make(chan struct{})
		// Buffered so the goroutine never blocks on sending its result.
		interruptRes := make(chan error, 1)
		defer func() {
			close(done)
			if ctxErr := <-interruptRes; ctxErr != nil {
				// The context was cancelled: report that to the caller
				// instead of whatever I/O error the closed connection
				// produced.
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// Closing the connection unblocks any pending I/O in the
				// handshake; the close error is deliberately discarded.
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// A failed handshake is final: keep returning its error.
	if err := c.handshakeErr; err != nil {
		return err
	}
	// Re-check under the mutex: another goroutine may have completed the
	// handshake while this one was waiting.
	if c.isHandshakeComplete.Load() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// Best effort: push out anything still buffered (such as an alert
		// describing the failure). The result is intentionally ignored.
		c.flush()
	}

	// The handshake function must either fail or mark the handshake as
	// complete; anything else is an internal inconsistency.
	if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	if c.quic != nil {
		if c.handshakeErr == nil {
			c.quicHandshakeComplete()
			// Hand the application-level read secret to the QUIC layer only
			// now that the handshake has completed.
			c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
		} else {
			// Find the alert that was recorded as the write error, if any;
			// fall back to internal_error.
			var a alert
			c.out.Lock()
			if !errors.As(c.out.err, &a) {
				a = alertInternalError
			}
			c.out.Unlock()
			// Wrap both the handshake error and the alert so that either
			// can be found with errors.Is/As. The %.0w verb wraps the alert
			// without adding any of its text to the message.
			c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
		}
		// Signal the QUIC layer that the handshake goroutine is finished.
		close(c.quic.blockedc)
		close(c.quic.signalc)
	}

	return c.handshakeErr
}
1594
1595
// ConnectionState returns a snapshot of the connection's negotiated
// parameters, taken while holding handshakeMutex.
func (c *Conn) ConnectionState() ConnectionState {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()
	return c.connectionStateLocked()
}
1601
// connectionStateLocked builds a ConnectionState from the connection's
// current fields; ConnectionState calls it with handshakeMutex held.
//
// TLSUnique is only populated before TLS 1.3, and only for full handshakes
// or resumed sessions that used the extended master secret. Exported keying
// material is only made available when renegotiation is disabled.
func (c *Conn) connectionStateLocked() ConnectionState {
	state := ConnectionState{
		HandshakeComplete:           c.isHandshakeComplete.Load(),
		Version:                     c.vers,
		NegotiatedProtocol:          c.clientProtocol,
		DidResume:                   c.didResume,
		NegotiatedProtocolIsMutual:  true,
		ServerName:                  c.serverName,
		CipherSuite:                 c.cipherSuite,
		PeerCertificates:            c.peerCertificates,
		VerifiedChains:              c.verifiedChains,
		SignedCertificateTimestamps: c.scts,
		OCSPResponse:                c.ocspResponse,
	}

	if c.vers != VersionTLS13 && (!c.didResume || c.extMasterSecret) {
		// TLSUnique is whichever Finished value was sent first.
		finished := &c.serverFinished
		if c.clientFinishedIsFirst {
			finished = &c.clientFinished
		}
		state.TLSUnique = finished[:]
	}

	state.ekm = c.ekm
	if c.config.Renegotiation != RenegotiateNever {
		state.ekm = noExportedKeyingMaterial
	}
	return state
}
1629
1630
1631
// OCSPResponse returns the OCSP response stored on the connection during
// the handshake, if any, read while holding handshakeMutex.
func (c *Conn) OCSPResponse() []byte {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	return c.ocspResponse
}
1638
1639
1640
1641
1642 func (c *Conn) VerifyHostname(host string) error {
1643 c.handshakeMutex.Lock()
1644 defer c.handshakeMutex.Unlock()
1645 if !c.isClient {
1646 return errors.New("tls: VerifyHostname called on TLS server connection")
1647 }
1648 if !c.isHandshakeComplete.Load() {
1649 return errors.New("tls: handshake has not yet been performed")
1650 }
1651 if len(c.verifiedChains) == 0 {
1652 return errors.New("tls: handshake did not verify certificate chain")
1653 }
1654 return c.peerCertificates[0].VerifyHostname(host)
1655 }
1656
View as plain text