Source file
src/net/dial.go
1
2
3
4
5 package net
6
7 import (
8 "context"
9 "internal/godebug"
10 "internal/nettrace"
11 "syscall"
12 "time"
13 )
14
// defaultTCPKeepAlive is a 15-second keep-alive period.
// NOTE(review): not referenced in this file; presumably the period used
// when a Dialer/ListenConfig KeepAlive is zero — confirm in the TCP code.
const defaultTCPKeepAlive = time.Second * 15

// defaultMPTCPEnabled is what mptcpStatus.get reports when neither an
// explicit SetMultipathTCP call nor GODEBUG multipathtcp=1 enables MPTCP.
const defaultMPTCPEnabled = false
24
// multipathtcp is the "multipathtcp" GODEBUG setting. When its value is
// "1", mptcpStatus.get enables MPTCP for any Dialer or ListenConfig that has
// not made an explicit choice via SetMultipathTCP.
var multipathtcp = godebug.New("multipathtcp")
26
27
// mptcpStatus records whether Multipath TCP was explicitly requested or
// refused. Its zero value, mptcpUseDefault, means "no explicit choice", so a
// zero Dialer or ListenConfig defers to the GODEBUG setting and the package
// default.
type mptcpStatus uint8

const (
	// mptcpUseDefault: no explicit choice was made (the zero value).
	mptcpUseDefault mptcpStatus = 0
	// mptcpEnabled: SetMultipathTCP(true) was called.
	mptcpEnabled mptcpStatus = 1
	// mptcpDisabled: SetMultipathTCP(false) was called.
	mptcpDisabled mptcpStatus = 2
)
36
37 func (m *mptcpStatus) get() bool {
38 switch *m {
39 case mptcpEnabled:
40 return true
41 case mptcpDisabled:
42 return false
43 }
44
45
46 if multipathtcp.Value() == "1" {
47 multipathtcp.IncNonDefault()
48
49 return true
50 }
51
52 return defaultMPTCPEnabled
53 }
54
55 func (m *mptcpStatus) set(use bool) {
56 if use {
57 *m = mptcpEnabled
58 } else {
59 *m = mptcpDisabled
60 }
61 }
62
63
64
65
66
67
68
69
// A Dialer contains options for connecting to an address.
//
// The zero Dialer is usable; the package-level Dial and DialTimeout
// functions dial through one.
type Dialer struct {
	// Timeout, if non-zero, bounds the whole dial: deadline() adds it to
	// the time the dial starts, and the earliest of that, Deadline, and the
	// context's deadline is applied before name resolution begins. When a
	// name resolves to several addresses, dialSerial divides the remaining
	// time among them (see partialDeadline).
	Timeout time.Duration

	// Deadline is an absolute point in time after which dials fail. The
	// zero value means no deadline. It is combined with Timeout and the
	// context deadline by taking the earliest non-zero value.
	Deadline time.Time

	// LocalAddr is the local address to dial from, or nil. It is passed to
	// resolveAddrList as a hint: a LocalAddr whose Network() differs from
	// the remote address's is rejected as a mismatch, and for TCP/UDP/IP a
	// non-wildcard LocalAddr filters out remote addresses of another IP
	// family. dialSingle type-asserts it to the concrete address type being
	// dialed, so a nil or non-matching LocalAddr becomes a nil local
	// address (NOTE(review): presumably letting the system choose — confirm).
	LocalAddr Addr

	// DualStack is not read anywhere in this file: dualStack() reports
	// FallbackDelay >= 0, so dual-stack racing is controlled by
	// FallbackDelay instead. NOTE(review): presumably kept only for API
	// compatibility.
	DualStack bool

	// FallbackDelay is how long a "tcp" dial waits before starting the
	// fallback racer when addresses are split into primary and fallback
	// groups (addrs.partition(isIPv4)). Zero means 300ms (fallbackDelay);
	// a negative value disables the split entirely (dualStack returns
	// false), so all addresses are dialed serially.
	FallbackDelay time.Duration

	// KeepAlive is not read in this file. NOTE(review): presumably the TCP
	// keep-alive period for dialed connections, with defaultTCPKeepAlive
	// used when zero — confirm in the TCP dial code.
	KeepAlive time.Duration

	// Resolver is used for name lookups; nil means DefaultResolver
	// (see resolver()).
	Resolver *Resolver

	// Cancel is an optional channel. DialContext starts a goroutine that
	// cancels the in-progress dial once a receive from Cancel succeeds
	// (for example, when it is closed). A context is the other way to
	// cancel a dial.
	Cancel <-chan struct{}

	// Control is not invoked in this file. NOTE(review): presumably called
	// with the raw connection after it is created and before it connects —
	// confirm in the socket code, including its interaction with
	// ControlContext.
	Control func(network, address string, c syscall.RawConn) error

	// ControlContext is the context-aware counterpart of Control; it is
	// not invoked in this file. NOTE(review): confirm precedence between
	// Control and ControlContext in the socket code.
	ControlContext func(ctx context.Context, network, address string, c syscall.RawConn) error

	// mptcpStatus records an explicit MPTCP choice made via
	// SetMultipathTCP; the zero value defers to GODEBUG and the package
	// default (see mptcpStatus.get).
	mptcpStatus mptcpStatus
}
160
161 func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
162
// minNonzeroTime returns the earlier of a and b, treating a zero Time as
// "unset": if exactly one is zero the other is returned, and if both are
// zero the zero Time is returned.
func minNonzeroTime(a, b time.Time) time.Time {
	switch {
	case a.IsZero():
		return b
	case b.IsZero(), a.Before(b):
		return a
	default:
		return b
	}
}
172
173
174
175
176
177
178
179 func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
180 if d.Timeout != 0 {
181 earliest = now.Add(d.Timeout)
182 }
183 if d, ok := ctx.Deadline(); ok {
184 earliest = minNonzeroTime(earliest, d)
185 }
186 return minNonzeroTime(earliest, d.Deadline)
187 }
188
189 func (d *Dialer) resolver() *Resolver {
190 if d.Resolver != nil {
191 return d.Resolver
192 }
193 return DefaultResolver
194 }
195
196
197
198 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
199 if deadline.IsZero() {
200 return deadline, nil
201 }
202 timeRemaining := deadline.Sub(now)
203 if timeRemaining <= 0 {
204 return time.Time{}, errTimeout
205 }
206
207 timeout := timeRemaining / time.Duration(addrsRemaining)
208
209 const saneMinimum = 2 * time.Second
210 if timeout < saneMinimum {
211 if timeRemaining < saneMinimum {
212 timeout = timeRemaining
213 } else {
214 timeout = saneMinimum
215 }
216 }
217 return now.Add(timeout), nil
218 }
219
220 func (d *Dialer) fallbackDelay() time.Duration {
221 if d.FallbackDelay > 0 {
222 return d.FallbackDelay
223 } else {
224 return 300 * time.Millisecond
225 }
226 }
227
228 func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
229 i := last(network, ':')
230 if i < 0 {
231 switch network {
232 case "tcp", "tcp4", "tcp6":
233 case "udp", "udp4", "udp6":
234 case "ip", "ip4", "ip6":
235 if needsProto {
236 return "", 0, UnknownNetworkError(network)
237 }
238 case "unix", "unixgram", "unixpacket":
239 default:
240 return "", 0, UnknownNetworkError(network)
241 }
242 return network, 0, nil
243 }
244 afnet = network[:i]
245 switch afnet {
246 case "ip", "ip4", "ip6":
247 protostr := network[i+1:]
248 proto, i, ok := dtoi(protostr)
249 if !ok || i != len(protostr) {
250 proto, err = lookupProtocol(ctx, protostr)
251 if err != nil {
252 return "", 0, err
253 }
254 }
255 return afnet, proto, nil
256 }
257 return "", 0, UnknownNetworkError(network)
258 }
259
260
261
262
263 func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
264 afnet, _, err := parseNetwork(ctx, network, true)
265 if err != nil {
266 return nil, err
267 }
268 if op == "dial" && addr == "" {
269 return nil, errMissingAddress
270 }
271 switch afnet {
272 case "unix", "unixgram", "unixpacket":
273 addr, err := ResolveUnixAddr(afnet, addr)
274 if err != nil {
275 return nil, err
276 }
277 if op == "dial" && hint != nil && addr.Network() != hint.Network() {
278 return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
279 }
280 return addrList{addr}, nil
281 }
282 addrs, err := r.internetAddrList(ctx, afnet, addr)
283 if err != nil || op != "dial" || hint == nil {
284 return addrs, err
285 }
286 var (
287 tcp *TCPAddr
288 udp *UDPAddr
289 ip *IPAddr
290 wildcard bool
291 )
292 switch hint := hint.(type) {
293 case *TCPAddr:
294 tcp = hint
295 wildcard = tcp.isWildcard()
296 case *UDPAddr:
297 udp = hint
298 wildcard = udp.isWildcard()
299 case *IPAddr:
300 ip = hint
301 wildcard = ip.isWildcard()
302 }
303 naddrs := addrs[:0]
304 for _, addr := range addrs {
305 if addr.Network() != hint.Network() {
306 return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
307 }
308 switch addr := addr.(type) {
309 case *TCPAddr:
310 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
311 continue
312 }
313 naddrs = append(naddrs, addr)
314 case *UDPAddr:
315 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
316 continue
317 }
318 naddrs = append(naddrs, addr)
319 case *IPAddr:
320 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
321 continue
322 }
323 naddrs = append(naddrs, addr)
324 }
325 }
326 if len(naddrs) == 0 {
327 return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
328 }
329 return naddrs, nil
330 }
331
332
333
334
335
336 func (d *Dialer) MultipathTCP() bool {
337 return d.mptcpStatus.get()
338 }
339
340
341
342
343
344
345
346 func (d *Dialer) SetMultipathTCP(use bool) {
347 d.mptcpStatus.set(use)
348 }
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398 func Dial(network, address string) (Conn, error) {
399 var d Dialer
400 return d.Dial(network, address)
401 }
402
403
404
405
406
407
408
409
410
411
412
413 func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
414 d := Dialer{Timeout: timeout}
415 return d.Dial(network, address)
416 }
417
418
// sysDialer carries the state for a single DialContext call: a copy of the
// caller's Dialer configuration plus the network and address being dialed.
type sysDialer struct {
	Dialer
	network, address string
	// testHookDialTCP is not referenced in this file. NOTE(review):
	// presumably, when set, it replaces the real TCP dial in tests — confirm.
	testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
}
424
425
426
427
428
429
430
431
432 func (d *Dialer) Dial(network, address string) (Conn, error) {
433 return d.DialContext(context.Background(), network, address)
434 }
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
// DialContext connects to the address on the named network using ctx.
//
// ctx must be non-nil; a nil ctx panics. The effective deadline is the
// earliest of now+d.Timeout, d.Deadline and ctx's deadline. It applies to
// name resolution as well as connecting. If d.Cancel is set, a successful
// receive from it cancels the dial. Resolution failures are returned as an
// *OpError with Op "dial". For network "tcp" with dual-stack enabled, the
// resolved addresses are split via partition(isIPv4) and raced by
// dialParallel; otherwise they are dialed serially.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
	if ctx == nil {
		panic("nil context")
	}
	deadline := d.deadline(ctx, time.Now())
	if !deadline.IsZero() {
		// Only derive a context when our deadline is tighter than the
		// caller's (or the caller has none). Note: d is shadowed here by
		// ctx's deadline; the receiver is untouched.
		if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
			subCtx, cancel := context.WithDeadline(ctx, deadline)
			defer cancel()
			ctx = subCtx
		}
	}
	if oldCancel := d.Cancel; oldCancel != nil {
		subCtx, cancel := context.WithCancel(ctx)
		defer cancel()
		// Bridge the Cancel channel to context cancellation. The goroutine
		// exits once subCtx is done, which the deferred cancel guarantees
		// by the time DialContext returns.
		go func() {
			select {
			case <-oldCancel:
				cancel()
			case <-subCtx.Done():
			}
		}()
		ctx = subCtx
	}

	// Resolve with a copy of the trace (if any) that has ConnectStart and
	// ConnectDone cleared. NOTE(review): presumably so connections made by
	// the resolver itself (e.g. to DNS servers) do not fire those hooks;
	// dialSingle still fires them for the real attempts using ctx.
	resolveCtx := ctx
	if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
		shadow := *trace
		shadow.ConnectStart = nil
		shadow.ConnectDone = nil
		resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
	}

	addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
	if err != nil {
		return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
	}

	// Copy the configuration so the per-call helpers see a stable snapshot.
	sd := &sysDialer{
		Dialer:  *d,
		network: network,
		address: address,
	}

	var primaries, fallbacks addrList
	if d.dualStack() && network == "tcp" {
		primaries, fallbacks = addrs.partition(isIPv4)
	} else {
		primaries = addrs
	}

	return sd.dialParallel(ctx, primaries, fallbacks)
}
508
509
510
511
512
// dialParallel races a serial dial of primaries against a serial dial of
// fallbacks, giving the primaries a head start of sd.fallbackDelay(). The
// fallback racer starts as soon as the primary racer fails, if that happens
// first. It returns the first successful connection. If both racers fail, it
// returns the primary racer's error. With no fallbacks it is just dialSerial.
// Connections produced by a racer after dialParallel has returned are closed,
// and the racers' contexts are cancelled on return.
func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
	if len(fallbacks) == 0 {
		return sd.dialSerial(ctx, primaries)
	}

	// Closed on return so that racers still running discard their results.
	returned := make(chan struct{})
	defer close(returned)

	type dialResult struct {
		Conn
		error
		primary bool
		done    bool
	}
	// Unbuffered: a racer either hands its result to the loop below or sees
	// returned closed and cleans up.
	results := make(chan dialResult)

	startRacer := func(ctx context.Context, primary bool) {
		ras := primaries
		if !primary {
			ras = fallbacks
		}
		c, err := sd.dialSerial(ctx, ras)
		select {
		case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
		case <-returned:
			// Nobody is listening any more; don't leak the connection.
			if c != nil {
				c.Close()
			}
		}
	}

	var primary, fallback dialResult

	// Start the primary racer immediately.
	primaryCtx, primaryCancel := context.WithCancel(ctx)
	defer primaryCancel()
	go startRacer(primaryCtx, true)

	// The fallback racer starts when this timer fires.
	fallbackTimer := time.NewTimer(sd.fallbackDelay())
	defer fallbackTimer.Stop()

	for {
		select {
		case <-fallbackTimer.C:
			// The timer fires at most once (it is only reset after a
			// successful Stop), so this defer is registered at most once.
			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
			defer fallbackCancel()
			go startRacer(fallbackCtx, false)

		case res := <-results:
			if res.error == nil {
				return res.Conn, nil
			}
			if res.primary {
				primary = res
			} else {
				fallback = res
			}
			if primary.done && fallback.done {
				return nil, primary.error
			}
			if res.primary && fallbackTimer.Stop() {
				// Stop succeeded, so the fallback had not started yet. The
				// primary has just failed, so start the fallback now.
				fallbackTimer.Reset(0)
			}
		}
	}
}
584
585
586
587 func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
588 var firstErr error
589
590 for i, ra := range ras {
591 select {
592 case <-ctx.Done():
593 return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
594 default:
595 }
596
597 dialCtx := ctx
598 if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
599 partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
600 if err != nil {
601
602 if firstErr == nil {
603 firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
604 }
605 break
606 }
607 if partialDeadline.Before(deadline) {
608 var cancel context.CancelFunc
609 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
610 defer cancel()
611 }
612 }
613
614 c, err := sd.dialSingle(dialCtx, ra)
615 if err == nil {
616 return c, nil
617 }
618 if firstErr == nil {
619 firstErr = err
620 }
621 }
622
623 if firstErr == nil {
624 firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
625 }
626 return nil, firstErr
627 }
628
629
630
631 func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
632 trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
633 if trace != nil {
634 raStr := ra.String()
635 if trace.ConnectStart != nil {
636 trace.ConnectStart(sd.network, raStr)
637 }
638 if trace.ConnectDone != nil {
639 defer func() { trace.ConnectDone(sd.network, raStr, err) }()
640 }
641 }
642 la := sd.LocalAddr
643 switch ra := ra.(type) {
644 case *TCPAddr:
645 la, _ := la.(*TCPAddr)
646 if sd.MultipathTCP() {
647 c, err = sd.dialMPTCP(ctx, la, ra)
648 } else {
649 c, err = sd.dialTCP(ctx, la, ra)
650 }
651 case *UDPAddr:
652 la, _ := la.(*UDPAddr)
653 c, err = sd.dialUDP(ctx, la, ra)
654 case *IPAddr:
655 la, _ := la.(*IPAddr)
656 c, err = sd.dialIP(ctx, la, ra)
657 case *UnixAddr:
658 la, _ := la.(*UnixAddr)
659 c, err = sd.dialUnix(ctx, la, ra)
660 default:
661 return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
662 }
663 if err != nil {
664 return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err}
665 }
666 return c, nil
667 }
668
669
// ListenConfig contains options for listening to an address. The zero value
// is usable; the package-level Listen and ListenPacket functions use one.
type ListenConfig struct {
	// Control is not invoked in this file. NOTE(review): presumably called
	// with the raw connection after it is created and before it is bound —
	// confirm in the socket code.
	Control func(network, address string, c syscall.RawConn) error

	// KeepAlive is not read in this file. NOTE(review): presumably the TCP
	// keep-alive period for accepted connections, with defaultTCPKeepAlive
	// used when zero — confirm in the TCP listener code.
	KeepAlive time.Duration

	// mptcpStatus records an explicit MPTCP choice made via
	// SetMultipathTCP; the zero value defers to GODEBUG and the package
	// default (see mptcpStatus.get).
	mptcpStatus mptcpStatus
}
692
693
694
695
696
697 func (lc *ListenConfig) MultipathTCP() bool {
698 return lc.mptcpStatus.get()
699 }
700
701
702
703
704
705
706
707 func (lc *ListenConfig) SetMultipathTCP(use bool) {
708 lc.mptcpStatus.set(use)
709 }
710
711
712
713
714
715 func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
716 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
717 if err != nil {
718 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
719 }
720 sl := &sysListener{
721 ListenConfig: *lc,
722 network: network,
723 address: address,
724 }
725 var l Listener
726 la := addrs.first(isIPv4)
727 switch la := la.(type) {
728 case *TCPAddr:
729 if sl.MultipathTCP() {
730 l, err = sl.listenMPTCP(ctx, la)
731 } else {
732 l, err = sl.listenTCP(ctx, la)
733 }
734 case *UnixAddr:
735 l, err = sl.listenUnix(ctx, la)
736 default:
737 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
738 }
739 if err != nil {
740 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
741 }
742 return l, nil
743 }
744
745
746
747
748
749 func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
750 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
751 if err != nil {
752 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
753 }
754 sl := &sysListener{
755 ListenConfig: *lc,
756 network: network,
757 address: address,
758 }
759 var c PacketConn
760 la := addrs.first(isIPv4)
761 switch la := la.(type) {
762 case *UDPAddr:
763 c, err = sl.listenUDP(ctx, la)
764 case *IPAddr:
765 c, err = sl.listenIP(ctx, la)
766 case *UnixAddr:
767 c, err = sl.listenUnixgram(ctx, la)
768 default:
769 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
770 }
771 if err != nil {
772 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
773 }
774 return c, nil
775 }
776
777
// sysListener carries the state for a single Listen or ListenPacket call: a
// copy of the caller's ListenConfig plus the requested network and address.
type sysListener struct {
	ListenConfig
	network, address string
}
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804 func Listen(network, address string) (Listener, error) {
805 var lc ListenConfig
806 return lc.Listen(context.Background(), network, address)
807 }
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834 func ListenPacket(network, address string) (PacketConn, error) {
835 var lc ListenConfig
836 return lc.ListenPacket(context.Background(), network, address)
837 }
838
View as plain text