Source file
src/net/dial.go
1
2
3
4
5 package net
6
7 import (
8 "context"
9 "internal/nettrace"
10 "syscall"
11 "time"
12 )
13
14
15
const (
	// defaultTCPKeepAlive is a fixed 15-second duration.
	// NOTE(review): presumably the keep-alive period applied to TCP
	// connections when the caller configures none; it is not referenced
	// in this part of the file, so confirm against its users.
	defaultTCPKeepAlive = 15 * time.Second
)
19
20
21
22
23
24
25
26
// A Dialer holds the options used when connecting to an address.
//
// The zero value is usable: Dial constructs a zero Dialer and DialTimeout
// constructs one with only Timeout set before calling its Dial method.
type Dialer struct {
	// Timeout bounds how long a dial may take. When it is non-zero the
	// deadline method adds it to the current time to form a candidate
	// deadline; when it is zero it contributes no deadline.
	Timeout time.Duration

	// Deadline is an absolute point in time after which the dial is
	// abandoned. The deadline method combines it with Timeout and the
	// context deadline and uses the earliest non-zero value; the zero
	// Time means no deadline.
	Deadline time.Time

	// LocalAddr is the local address to use when dialing. DialContext
	// passes it to address resolution as a hint (resolveAddrList rejects
	// a hint whose network does not match the resolved addresses), and
	// dialSingle hands it to the per-protocol dial function when its
	// concrete type matches the remote address type; otherwise a nil
	// local address is used.
	LocalAddr Addr

	// DualStack is not read by the code in this part of the file: the
	// dualStack method consults FallbackDelay instead.
	// NOTE(review): presumably retained for backward compatibility —
	// confirm before relying on it.
	DualStack bool

	// FallbackDelay is how long dialParallel waits before starting the
	// fallback connection attempt alongside the primary one. Zero (or
	// any non-positive value passed through fallbackDelay) selects a
	// default of 300ms. A negative value makes dualStack report false,
	// which stops DialContext from splitting the addresses into
	// primaries and fallbacks.
	FallbackDelay time.Duration

	// KeepAlive is not read by the code in this part of the file.
	// NOTE(review): presumably the keep-alive period for the resulting
	// connection — confirm against the protocol-specific dial code.
	KeepAlive time.Duration

	// Resolver optionally specifies the resolver to use. When it is nil
	// the resolver method falls back to DefaultResolver.
	Resolver *Resolver

	// Cancel is an optional channel. When it is non-nil, DialContext
	// cancels the dial's context as soon as a receive on it succeeds
	// (a value is sent or the channel is closed).
	Cancel <-chan struct{}

	// Control is not invoked by the code in this part of the file.
	// NOTE(review): presumably called with the raw connection after it
	// is created and before it is connected — confirm against the
	// socket-creation code.
	Control func(network, address string, c syscall.RawConn) error

	// ControlContext is not invoked by the code in this part of the
	// file.
	// NOTE(review): presumably the context-aware counterpart of Control;
	// how the two interact when both are set is not visible here —
	// confirm against the socket-creation code.
	ControlContext func(ctx context.Context, network, address string, c syscall.RawConn) error
}
112
113 func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
114
115 func minNonzeroTime(a, b time.Time) time.Time {
116 if a.IsZero() {
117 return b
118 }
119 if b.IsZero() || a.Before(b) {
120 return a
121 }
122 return b
123 }
124
125
126
127
128
129
130
131 func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
132 if d.Timeout != 0 {
133 earliest = now.Add(d.Timeout)
134 }
135 if d, ok := ctx.Deadline(); ok {
136 earliest = minNonzeroTime(earliest, d)
137 }
138 return minNonzeroTime(earliest, d.Deadline)
139 }
140
141 func (d *Dialer) resolver() *Resolver {
142 if d.Resolver != nil {
143 return d.Resolver
144 }
145 return DefaultResolver
146 }
147
148
149
150 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
151 if deadline.IsZero() {
152 return deadline, nil
153 }
154 timeRemaining := deadline.Sub(now)
155 if timeRemaining <= 0 {
156 return time.Time{}, errTimeout
157 }
158
159 timeout := timeRemaining / time.Duration(addrsRemaining)
160
161 const saneMinimum = 2 * time.Second
162 if timeout < saneMinimum {
163 if timeRemaining < saneMinimum {
164 timeout = timeRemaining
165 } else {
166 timeout = saneMinimum
167 }
168 }
169 return now.Add(timeout), nil
170 }
171
172 func (d *Dialer) fallbackDelay() time.Duration {
173 if d.FallbackDelay > 0 {
174 return d.FallbackDelay
175 } else {
176 return 300 * time.Millisecond
177 }
178 }
179
180 func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
181 i := last(network, ':')
182 if i < 0 {
183 switch network {
184 case "tcp", "tcp4", "tcp6":
185 case "udp", "udp4", "udp6":
186 case "ip", "ip4", "ip6":
187 if needsProto {
188 return "", 0, UnknownNetworkError(network)
189 }
190 case "unix", "unixgram", "unixpacket":
191 default:
192 return "", 0, UnknownNetworkError(network)
193 }
194 return network, 0, nil
195 }
196 afnet = network[:i]
197 switch afnet {
198 case "ip", "ip4", "ip6":
199 protostr := network[i+1:]
200 proto, i, ok := dtoi(protostr)
201 if !ok || i != len(protostr) {
202 proto, err = lookupProtocol(ctx, protostr)
203 if err != nil {
204 return "", 0, err
205 }
206 }
207 return afnet, proto, nil
208 }
209 return "", 0, UnknownNetworkError(network)
210 }
211
212
213
214
215 func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
216 afnet, _, err := parseNetwork(ctx, network, true)
217 if err != nil {
218 return nil, err
219 }
220 if op == "dial" && addr == "" {
221 return nil, errMissingAddress
222 }
223 switch afnet {
224 case "unix", "unixgram", "unixpacket":
225 addr, err := ResolveUnixAddr(afnet, addr)
226 if err != nil {
227 return nil, err
228 }
229 if op == "dial" && hint != nil && addr.Network() != hint.Network() {
230 return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
231 }
232 return addrList{addr}, nil
233 }
234 addrs, err := r.internetAddrList(ctx, afnet, addr)
235 if err != nil || op != "dial" || hint == nil {
236 return addrs, err
237 }
238 var (
239 tcp *TCPAddr
240 udp *UDPAddr
241 ip *IPAddr
242 wildcard bool
243 )
244 switch hint := hint.(type) {
245 case *TCPAddr:
246 tcp = hint
247 wildcard = tcp.isWildcard()
248 case *UDPAddr:
249 udp = hint
250 wildcard = udp.isWildcard()
251 case *IPAddr:
252 ip = hint
253 wildcard = ip.isWildcard()
254 }
255 naddrs := addrs[:0]
256 for _, addr := range addrs {
257 if addr.Network() != hint.Network() {
258 return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
259 }
260 switch addr := addr.(type) {
261 case *TCPAddr:
262 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
263 continue
264 }
265 naddrs = append(naddrs, addr)
266 case *UDPAddr:
267 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
268 continue
269 }
270 naddrs = append(naddrs, addr)
271 case *IPAddr:
272 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
273 continue
274 }
275 naddrs = append(naddrs, addr)
276 }
277 }
278 if len(naddrs) == 0 {
279 return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
280 }
281 return naddrs, nil
282 }
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332 func Dial(network, address string) (Conn, error) {
333 var d Dialer
334 return d.Dial(network, address)
335 }
336
337
338
339
340
341
342
343
344
345
346
347 func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
348 d := Dialer{Timeout: timeout}
349 return d.Dial(network, address)
350 }
351
352
// sysDialer bundles a copy of a Dialer with the network and address
// strings of one dial operation; DialContext builds one per call.
type sysDialer struct {
	Dialer
	network, address string
	// testHookDialTCP is neither set nor called in this part of the file.
	// NOTE(review): presumably lets tests substitute the TCP dial
	// function — confirm against the TCP dial code.
	testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
}
358
359
360
361
362
363
364
365
366 func (d *Dialer) Dial(network, address string) (Conn, error) {
367 return d.DialContext(context.Background(), network, address)
368 }
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388 func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
389 if ctx == nil {
390 panic("nil context")
391 }
392 deadline := d.deadline(ctx, time.Now())
393 if !deadline.IsZero() {
394 if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
395 subCtx, cancel := context.WithDeadline(ctx, deadline)
396 defer cancel()
397 ctx = subCtx
398 }
399 }
400 if oldCancel := d.Cancel; oldCancel != nil {
401 subCtx, cancel := context.WithCancel(ctx)
402 defer cancel()
403 go func() {
404 select {
405 case <-oldCancel:
406 cancel()
407 case <-subCtx.Done():
408 }
409 }()
410 ctx = subCtx
411 }
412
413
414 resolveCtx := ctx
415 if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
416 shadow := *trace
417 shadow.ConnectStart = nil
418 shadow.ConnectDone = nil
419 resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
420 }
421
422 addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
423 if err != nil {
424 return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
425 }
426
427 sd := &sysDialer{
428 Dialer: *d,
429 network: network,
430 address: address,
431 }
432
433 var primaries, fallbacks addrList
434 if d.dualStack() && network == "tcp" {
435 primaries, fallbacks = addrs.partition(isIPv4)
436 } else {
437 primaries = addrs
438 }
439
440 return sd.dialParallel(ctx, primaries, fallbacks)
441 }
442
443
444
445
446
447 func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
448 if len(fallbacks) == 0 {
449 return sd.dialSerial(ctx, primaries)
450 }
451
452 returned := make(chan struct{})
453 defer close(returned)
454
455 type dialResult struct {
456 Conn
457 error
458 primary bool
459 done bool
460 }
461 results := make(chan dialResult)
462
463 startRacer := func(ctx context.Context, primary bool) {
464 ras := primaries
465 if !primary {
466 ras = fallbacks
467 }
468 c, err := sd.dialSerial(ctx, ras)
469 select {
470 case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
471 case <-returned:
472 if c != nil {
473 c.Close()
474 }
475 }
476 }
477
478 var primary, fallback dialResult
479
480
481 primaryCtx, primaryCancel := context.WithCancel(ctx)
482 defer primaryCancel()
483 go startRacer(primaryCtx, true)
484
485
486 fallbackTimer := time.NewTimer(sd.fallbackDelay())
487 defer fallbackTimer.Stop()
488
489 for {
490 select {
491 case <-fallbackTimer.C:
492 fallbackCtx, fallbackCancel := context.WithCancel(ctx)
493 defer fallbackCancel()
494 go startRacer(fallbackCtx, false)
495
496 case res := <-results:
497 if res.error == nil {
498 return res.Conn, nil
499 }
500 if res.primary {
501 primary = res
502 } else {
503 fallback = res
504 }
505 if primary.done && fallback.done {
506 return nil, primary.error
507 }
508 if res.primary && fallbackTimer.Stop() {
509
510
511
512
513 fallbackTimer.Reset(0)
514 }
515 }
516 }
517 }
518
519
520
521 func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
522 var firstErr error
523
524 for i, ra := range ras {
525 select {
526 case <-ctx.Done():
527 return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
528 default:
529 }
530
531 dialCtx := ctx
532 if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
533 partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
534 if err != nil {
535
536 if firstErr == nil {
537 firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
538 }
539 break
540 }
541 if partialDeadline.Before(deadline) {
542 var cancel context.CancelFunc
543 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
544 defer cancel()
545 }
546 }
547
548 c, err := sd.dialSingle(dialCtx, ra)
549 if err == nil {
550 return c, nil
551 }
552 if firstErr == nil {
553 firstErr = err
554 }
555 }
556
557 if firstErr == nil {
558 firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
559 }
560 return nil, firstErr
561 }
562
563
564
565 func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
566 trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
567 if trace != nil {
568 raStr := ra.String()
569 if trace.ConnectStart != nil {
570 trace.ConnectStart(sd.network, raStr)
571 }
572 if trace.ConnectDone != nil {
573 defer func() { trace.ConnectDone(sd.network, raStr, err) }()
574 }
575 }
576 la := sd.LocalAddr
577 switch ra := ra.(type) {
578 case *TCPAddr:
579 la, _ := la.(*TCPAddr)
580 c, err = sd.dialTCP(ctx, la, ra)
581 case *UDPAddr:
582 la, _ := la.(*UDPAddr)
583 c, err = sd.dialUDP(ctx, la, ra)
584 case *IPAddr:
585 la, _ := la.(*IPAddr)
586 c, err = sd.dialIP(ctx, la, ra)
587 case *UnixAddr:
588 la, _ := la.(*UnixAddr)
589 c, err = sd.dialUnix(ctx, la, ra)
590 default:
591 return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
592 }
593 if err != nil {
594 return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err}
595 }
596 return c, nil
597 }
598
599
// ListenConfig holds the options used when listening on an address.
// The package-level Listen and ListenPacket functions use its zero
// value.
type ListenConfig struct {
	// Control is not invoked by the code in this part of the file.
	// NOTE(review): presumably called with the raw connection after it
	// is created and before it is bound — confirm against the
	// socket-creation code.
	Control func(network, address string, c syscall.RawConn) error

	// KeepAlive is not read by the code in this part of the file.
	// NOTE(review): presumably the keep-alive period for connections
	// accepted by the listener — confirm against the listener code.
	KeepAlive time.Duration
}
617
618
619
620
621
622 func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
623 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
624 if err != nil {
625 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
626 }
627 sl := &sysListener{
628 ListenConfig: *lc,
629 network: network,
630 address: address,
631 }
632 var l Listener
633 la := addrs.first(isIPv4)
634 switch la := la.(type) {
635 case *TCPAddr:
636 l, err = sl.listenTCP(ctx, la)
637 case *UnixAddr:
638 l, err = sl.listenUnix(ctx, la)
639 default:
640 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
641 }
642 if err != nil {
643 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
644 }
645 return l, nil
646 }
647
648
649
650
651
652 func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
653 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
654 if err != nil {
655 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
656 }
657 sl := &sysListener{
658 ListenConfig: *lc,
659 network: network,
660 address: address,
661 }
662 var c PacketConn
663 la := addrs.first(isIPv4)
664 switch la := la.(type) {
665 case *UDPAddr:
666 c, err = sl.listenUDP(ctx, la)
667 case *IPAddr:
668 c, err = sl.listenIP(ctx, la)
669 case *UnixAddr:
670 c, err = sl.listenUnixgram(ctx, la)
671 default:
672 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
673 }
674 if err != nil {
675 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
676 }
677 return c, nil
678 }
679
680
// sysListener bundles a copy of a ListenConfig with the network and
// address strings of one listen operation; the Listen and ListenPacket
// methods of ListenConfig each build one per call.
type sysListener struct {
	ListenConfig
	network, address string
}
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707 func Listen(network, address string) (Listener, error) {
708 var lc ListenConfig
709 return lc.Listen(context.Background(), network, address)
710 }
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737 func ListenPacket(network, address string) (PacketConn, error) {
738 var lc ListenConfig
739 return lc.ListenPacket(context.Background(), network, address)
740 }
741
View as plain text