Source file
src/net/lookup.go
1
2
3
4
5 package net
6
7 import (
8 "context"
9 "errors"
10 "internal/nettrace"
11 "internal/singleflight"
12 "net/netip"
13 "sync"
14
15 "golang.org/x/net/dns/dnsmessage"
16 )
17
18
19
20
21
22
23
24
var (
	// protocols is the small built-in table of IP protocol numbers consulted
	// by lookupProtocolMap. Keys are lower-case protocol names.
	protocols = map[string]int{
		"icmp":      1,
		"igmp":      2,
		"tcp":       6,
		"udp":       17,
		"ipv6-icmp": 58,
	}

	// services is the small built-in table of well-known service ports
	// consulted by lookupPortMap, keyed first by base network ("tcp" or
	// "udp") and then by lower-case service name.
	services = map[string]map[string]int{
		"udp": {
			"domain": 53,
		},
		"tcp": {
			"ftp":    21,
			"ftps":   990,
			"gopher": 70,
			"http":   80,
			"https":  443,
			"imap2":  143,
			"imap3":  220,
			"imaps":  993,
			"pop3":   110,
			"pop3s":  995,
			"smtp":   25,
			"ssh":    22,
			"telnet": 23,
		},
	}
)
59
60
61
// dnsWaitGroup tracks lookups started by lookupIPAddr. It is incremented just
// before each singleflight DoChan call and decremented once that call's result
// channel has delivered. The decrement happens either inline or, when the
// caller's context ended first, from a background goroutine.
// NOTE(review): presumably waited on by tests to drain outstanding lookups; no
// Wait call is visible in this file.
var dnsWaitGroup sync.WaitGroup
63
64 const maxProtoLength = len("RSVP-E2E-IGNORE") + 10
65
66 func lookupProtocolMap(name string) (int, error) {
67 var lowerProtocol [maxProtoLength]byte
68 n := copy(lowerProtocol[:], name)
69 lowerASCIIBytes(lowerProtocol[:n])
70 proto, found := protocols[string(lowerProtocol[:n])]
71 if !found || n != len(name) {
72 return 0, &AddrError{Err: "unknown IP protocol specified", Addr: name}
73 }
74 return proto, nil
75 }
76
77
78
79
80
81
82 const maxPortBufSize = len("mobility-header") + 10
83
84 func lookupPortMap(network, service string) (port int, error error) {
85 switch network {
86 case "tcp4", "tcp6":
87 network = "tcp"
88 case "udp4", "udp6":
89 network = "udp"
90 }
91
92 if m, ok := services[network]; ok {
93 var lowerService [maxPortBufSize]byte
94 n := copy(lowerService[:], service)
95 lowerASCIIBytes(lowerService[:n])
96 if port, ok := m[string(lowerService[:n])]; ok && n == len(service) {
97 return port, nil
98 }
99 }
100 return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
101 }
102
103
104
// ipVersion returns the trailing '4' or '6' of a network name such as "tcp4"
// or "udp6". It returns 0 when the name is empty or ends in anything else.
func ipVersion(network string) byte {
	if network == "" {
		return 0
	}
	switch last := network[len(network)-1]; last {
	case '4', '6':
		return last
	default:
		return 0
	}
}
115
116
117
// DefaultResolver is the zero-value Resolver used by the package-level Lookup*
// functions (LookupHost, LookupIP, LookupPort, ...). Its lookupGroup is also
// shared with nil *Resolver receivers via getLookupGroup.
var DefaultResolver = &Resolver{}
119
120
121
122
// Resolver looks up names and numbers.
//
// The zero value is usable (DefaultResolver is &Resolver{}). Several helpers
// in this file treat a nil *Resolver like the zero value: preferGo,
// strictErrors, getLookupGroup and dial.
type Resolver struct {
	// PreferGo is reported by preferGo.
	// NOTE(review): presumably selects the pure-Go resolver over the system
	// one; the selection logic lives outside this file — confirm.
	PreferGo bool

	// StrictErrors is reported by strictErrors.
	// NOTE(review): how it changes error handling (e.g. treating partial
	// failures as fatal) is implemented elsewhere — confirm before relying on
	// it.
	StrictErrors bool

	// Dial, if non-nil, is used by dial to open connections to DNS servers
	// instead of a zero-value Dialer's DialContext. Its errors are passed
	// through mapErr.
	Dial func(ctx context.Context, network, address string) (Conn, error)

	// lookupGroup coalesces concurrent lookupIPAddr calls that share the
	// same network and host (key network+"\000"+host) into one in-flight
	// lookup.
	lookupGroup singleflight.Group
}
160
161 func (r *Resolver) preferGo() bool { return r != nil && r.PreferGo }
162 func (r *Resolver) strictErrors() bool { return r != nil && r.StrictErrors }
163
164 func (r *Resolver) getLookupGroup() *singleflight.Group {
165 if r == nil {
166 return &DefaultResolver.lookupGroup
167 }
168 return &r.lookupGroup
169 }
170
171
172
173
174
175
// LookupHost looks up the given host using the local resolver and returns a
// slice of that host's addresses.
//
// LookupHost uses context.Background internally; to specify the context, use
// Resolver.LookupHost.
func LookupHost(host string) (addrs []string, err error) {
	ctx := context.Background()
	return DefaultResolver.LookupHost(ctx, host)
}

// LookupHost looks up the given host using the local resolver and returns a
// slice of that host's addresses.
//
// An empty host yields a not-found *DNSError. An IP literal (optionally with
// a zone) is returned unchanged without consulting any resolver.
func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
	// Reject the empty name before anything else can interpret it.
	if host == "" {
		return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
	}
	if ip, _ := parseIPZone(host); ip == nil {
		// Not a literal address: ask the resolver.
		return r.lookupHost(ctx, host)
	}
	// host is already an IP literal; hand it back as-is.
	return []string{host}, nil
}
193
194
195
196 func LookupIP(host string) ([]IP, error) {
197 addrs, err := DefaultResolver.LookupIPAddr(context.Background(), host)
198 if err != nil {
199 return nil, err
200 }
201 ips := make([]IP, len(addrs))
202 for i, ia := range addrs {
203 ips[i] = ia.IP
204 }
205 return ips, nil
206 }
207
208
209
210 func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
211 return r.lookupIPAddr(ctx, "ip", host)
212 }
213
214
215
216
217
// LookupIP looks up host for the given network using the local resolver.
//
// After parseNetwork, the address family must be "ip", "ip4" or "ip6";
// anything else yields an UnknownNetworkError. An empty host yields a
// not-found *DNSError. Only the IP part of each resolved address is returned.
func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, error) {
	afnet, _, err := parseNetwork(ctx, network, false)
	if err != nil {
		return nil, err
	}
	if afnet != "ip" && afnet != "ip4" && afnet != "ip6" {
		return nil, UnknownNetworkError(network)
	}
	if host == "" {
		return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
	}

	list, err := r.internetAddrList(ctx, afnet, host)
	if err != nil {
		return nil, err
	}
	// For IP networks every entry is expected to be an *IPAddr.
	ips := make([]IP, len(list))
	for i, a := range list {
		ips[i] = a.(*IPAddr).IP
	}
	return ips, nil
}
243
244
245
246
247
// LookupNetIP is like LookupIP but returns netip.Addr values.
//
// It is currently a conversion wrapper around LookupIP. Each net.IP slice is
// converted with netip.AddrFromSlice, and slices that conversion rejects are
// silently skipped.
// NOTE(review): a 16-byte IPv4-mapped slice converts to an IPv4-mapped IPv6
// netip.Addr; confirm whether callers expect Unmap'd IPv4 addresses.
func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
	ips, err := r.LookupIP(ctx, network, host)
	if err != nil {
		return nil, err
	}
	addrs := make([]netip.Addr, 0, len(ips))
	for _, ip := range ips {
		a, ok := netip.AddrFromSlice(ip)
		if !ok {
			continue
		}
		addrs = append(addrs, a)
	}
	return addrs, nil
}
265
266
267
268 type onlyValuesCtx struct {
269 context.Context
270 lookupValues context.Context
271 }
272
273 var _ context.Context = (*onlyValuesCtx)(nil)
274
275
276 func (ovc *onlyValuesCtx) Value(key any) any {
277 select {
278 case <-ovc.lookupValues.Done():
279 return nil
280 default:
281 return ovc.lookupValues.Value(key)
282 }
283 }
284
285
286
287
288
289 func withUnexpiredValuesPreserved(lookupCtx context.Context) context.Context {
290 return &onlyValuesCtx{Context: context.Background(), lookupValues: lookupCtx}
291 }
292
293
294
// lookupIPAddr looks up host for network ("ip", "ip4" or "ip6") and returns its
// addresses.
//
// The steps are:
//   - An empty host fails immediately with a not-found *DNSError.
//   - An IP literal (optionally with a zone) is returned directly.
//   - Otherwise, concurrent calls with the same network and host share a single
//     in-flight lookup through the resolver's singleflight group.
//
// The lookup runs on a context detached from ctx's cancellation, so one
// caller giving up does not fail the others. If ctx ends first, a *DNSError
// describing ctx's error is returned. Non-*DNSError failures from the
// resolver are wrapped in a *DNSError for host. nettrace hooks found in ctx
// are invoked at start and finish.
func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) {
	// Make sure that no matter what happens later, host == "" is rejected.
	if host == "" {
		return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
	}
	if ip, zone := parseIPZone(host); ip != nil {
		return []IPAddr{{IP: ip, Zone: zone}}, nil
	}
	trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
	if trace != nil && trace.DNSStart != nil {
		trace.DNSStart(host)
	}

	// The resolver function defaults to r.lookupIP but may be overridden via
	// the context (nettrace.LookupIPAltResolverKey).
	resolverFunc := r.lookupIP
	if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string, string) ([]IPAddr, error)); alt != nil {
		resolverFunc = alt
	}

	// The shared lookup may serve several callers, so it must not inherit the
	// first caller's cancellation or deadline. It runs on a context that keeps
	// ctx's values (until ctx ends) and has its own cancel func.
	lookupGroupCtx, lookupGroupCancel := context.WithCancel(withUnexpiredValuesPreserved(ctx))

	lookupKey := network + "\000" + host
	dnsWaitGroup.Add(1)
	ch := r.getLookupGroup().DoChan(lookupKey, func() (any, error) {
		return testHookLookupIP(lookupGroupCtx, resolverFunc, network, host)
	})

	// dnsWaitGroupDone waits for the shared lookup to deliver its result, then
	// releases this call's dnsWaitGroup slot and runs cancelFn.
	dnsWaitGroupDone := func(ch <-chan singleflight.Result, cancelFn context.CancelFunc) {
		<-ch
		dnsWaitGroup.Done()
		cancelFn()
	}
	select {
	case <-ctx.Done():
		// Our caller gave up. ForgetUnshared reports whether the key was
		// dropped because no other goroutine is waiting on it. In that case
		// the lookup is canceled right away; otherwise it keeps running for
		// the other waiters and is canceled only once it completes. Either
		// way a goroutine drains ch so dnsWaitGroup stays balanced.
		if r.getLookupGroup().ForgetUnshared(lookupKey) {
			lookupGroupCancel()
			go dnsWaitGroupDone(ch, func() {})
		} else {
			go dnsWaitGroupDone(ch, lookupGroupCancel)
		}
		ctxErr := ctx.Err()
		err := &DNSError{
			Err:       mapErr(ctxErr).Error(),
			Name:      host,
			IsTimeout: errors.Is(ctxErr, context.DeadlineExceeded),
		}
		if trace != nil && trace.DNSDone != nil {
			trace.DNSDone(nil, false, err)
		}
		return nil, err
	case res := <-ch:
		// res, not r: avoid shadowing the *Resolver receiver.
		dnsWaitGroup.Done()
		lookupGroupCancel()
		err := res.Err
		if err != nil {
			// Normalize foreign errors into a *DNSError for this host,
			// keeping whether the failure was a timeout. errors.Is also
			// catches a wrapped DeadlineExceeded, which == would miss.
			if _, ok := err.(*DNSError); !ok {
				isTimeout := false
				if errors.Is(err, context.DeadlineExceeded) {
					isTimeout = true
				} else if terr, ok := err.(timeout); ok {
					isTimeout = terr.Timeout()
				}
				err = &DNSError{
					Err:       err.Error(),
					Name:      host,
					IsTimeout: isTimeout,
				}
			}
		}
		if trace != nil && trace.DNSDone != nil {
			addrs, _ := res.Val.([]IPAddr)
			trace.DNSDone(ipAddrsEface(addrs), res.Shared, err)
		}
		return lookupIPReturn(res.Val, err, res.Shared)
	}
}
385
386
387
388 func lookupIPReturn(addrsi any, err error, shared bool) ([]IPAddr, error) {
389 if err != nil {
390 return nil, err
391 }
392 addrs := addrsi.([]IPAddr)
393 if shared {
394 clone := make([]IPAddr, len(addrs))
395 copy(clone, addrs)
396 addrs = clone
397 }
398 return addrs, nil
399 }
400
401
402 func ipAddrsEface(addrs []IPAddr) []any {
403 s := make([]any, len(addrs))
404 for i, v := range addrs {
405 s[i] = v
406 }
407 return s
408 }
409
410
411
412
413
// LookupPort looks up the port for the given network and service.
//
// LookupPort uses context.Background internally; to specify the context, use
// Resolver.LookupPort.
func LookupPort(network, service string) (port int, err error) {
	ctx := context.Background()
	return DefaultResolver.LookupPort(ctx, network, service)
}

// LookupPort looks up the port for the given network and service.
//
// parsePort decides whether service needs a name lookup. Named services are
// only looked up for the tcp and udp families. An empty network is treated as
// "ip", and any other network yields an *AddrError ("unknown network"). The
// resulting port must lie in [0, 65535], or an *AddrError ("invalid port")
// is returned.
func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
	port, needsLookup := parsePort(service)
	if needsLookup {
		switch network {
		case "":
			network = "ip"
		case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
			// Supported as given.
		default:
			return 0, &AddrError{Err: "unknown network", Addr: network}
		}
		if port, err = r.lookupPort(ctx, network, service); err != nil {
			return 0, err
		}
	}
	if port < 0 || port > 65535 {
		return 0, &AddrError{Err: "invalid port", Addr: service}
	}
	return port, nil
}
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
// LookupCNAME returns the canonical name for the given host.
//
// LookupCNAME uses context.Background internally; to specify the context, use
// Resolver.LookupCNAME.
func LookupCNAME(host string) (cname string, err error) {
	ctx := context.Background()
	return DefaultResolver.LookupCNAME(ctx, host)
}

// LookupCNAME returns the canonical name for the given host.
//
// If the name returned by the underlying lookup is not a valid domain name,
// LookupCNAME returns "" and a *DNSError carrying
// errMalformedDNSRecordsDetail.
func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) {
	cname, err := r.lookupCNAME(ctx, host)
	switch {
	case err != nil:
		return "", err
	case !isDomainName(cname):
		return "", &DNSError{Err: errMalformedDNSRecordsDetail, Name: host}
	default:
		return cname, nil
	}
}
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
// LookupSRV resolves an SRV query for the given service, protocol and domain
// name, returning the canonical name and the SRV records.
//
// LookupSRV uses context.Background internally; to specify the context, use
// Resolver.LookupSRV.
func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
	ctx := context.Background()
	return DefaultResolver.LookupSRV(ctx, service, proto, name)
}

// LookupSRV resolves an SRV query for the given service, protocol and domain
// name, returning the canonical name and the SRV records.
//
// A non-empty but invalid canonical name fails the whole call. Nil records
// and records whose Target is not a valid domain name are dropped. If any
// were dropped, the remaining records are returned together with a *DNSError
// carrying errMalformedDNSRecordsDetail.
func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
	cname, addrs, err := r.lookupSRV(ctx, service, proto, name)
	if err != nil {
		return "", nil, err
	}
	if cname != "" && !isDomainName(cname) {
		return "", nil, &DNSError{Err: "SRV header name is invalid", Name: name}
	}
	valid := make([]*SRV, 0, len(addrs))
	for _, srv := range addrs {
		if srv != nil && isDomainName(srv.Target) {
			valid = append(valid, srv)
		}
	}
	if len(valid) < len(addrs) {
		return cname, valid, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
	}
	return cname, valid, nil
}
539
540
541
542
543
544
545
546
547
548
// LookupMX returns the DNS MX records for the given domain name.
//
// LookupMX uses context.Background internally; to specify the context, use
// Resolver.LookupMX.
func LookupMX(name string) ([]*MX, error) {
	ctx := context.Background()
	return DefaultResolver.LookupMX(ctx, name)
}

// LookupMX returns the DNS MX records for the given domain name.
//
// Nil records and records whose Host is not a valid domain name are dropped.
// If any were dropped, the remaining records are returned together with a
// *DNSError carrying errMalformedDNSRecordsDetail.
func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
	records, err := r.lookupMX(ctx, name)
	if err != nil {
		return nil, err
	}
	valid := make([]*MX, 0, len(records))
	for _, rec := range records {
		if rec != nil && isDomainName(rec.Host) {
			valid = append(valid, rec)
		}
	}
	if len(valid) < len(records) {
		return valid, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
	}
	return valid, nil
}
579
580
581
582
583
584
585
586
587
588
// LookupNS returns the DNS NS records for the given domain name.
//
// LookupNS uses context.Background internally; to specify the context, use
// Resolver.LookupNS.
func LookupNS(name string) ([]*NS, error) {
	ctx := context.Background()
	return DefaultResolver.LookupNS(ctx, name)
}

// LookupNS returns the DNS NS records for the given domain name.
//
// Nil records and records whose Host is not a valid domain name are dropped.
// If any were dropped, the remaining records are returned together with a
// *DNSError carrying errMalformedDNSRecordsDetail.
func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
	records, err := r.lookupNS(ctx, name)
	if err != nil {
		return nil, err
	}
	valid := make([]*NS, 0, len(records))
	for _, rec := range records {
		if rec != nil && isDomainName(rec.Host) {
			valid = append(valid, rec)
		}
	}
	if len(valid) < len(records) {
		return valid, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
	}
	return valid, nil
}
619
620
621
622
623
// LookupTXT returns the DNS TXT records for the given domain name.
//
// LookupTXT uses context.Background internally; to specify the context, use
// Resolver.LookupTXT.
func LookupTXT(name string) ([]string, error) {
	// Go through the exported method like every other package-level wrapper,
	// so any future logic added to Resolver.LookupTXT applies here too.
	return DefaultResolver.LookupTXT(context.Background(), name)
}

// LookupTXT returns the DNS TXT records for the given domain name.
//
// Unlike LookupMX, LookupNS and LookupSRV, the results are returned as
// produced by the underlying lookup, without isDomainName filtering.
func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
	return r.lookupTXT(ctx, name)
}
632
633
634
635
636
637
638
639
640
641
642
643
644
// LookupAddr returns the names that the underlying lookup reports for addr
// (a reverse lookup).
//
// LookupAddr uses context.Background internally; to specify the context, use
// Resolver.LookupAddr.
func LookupAddr(addr string) (names []string, err error) {
	ctx := context.Background()
	return DefaultResolver.LookupAddr(ctx, addr)
}

// LookupAddr returns the names that the underlying lookup reports for addr.
//
// Names that are not valid domain names are dropped. If any were dropped, the
// remaining names are returned together with a *DNSError carrying
// errMalformedDNSRecordsDetail.
func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
	names, err := r.lookupAddr(ctx, addr)
	if err != nil {
		return nil, err
	}
	valid := make([]string, 0, len(names))
	for _, n := range names {
		if !isDomainName(n) {
			continue
		}
		valid = append(valid, n)
	}
	if len(valid) < len(names) {
		return valid, &DNSError{Err: errMalformedDNSRecordsDetail, Name: addr}
	}
	return valid, nil
}
671
672
673
674
// errMalformedDNSRecordsDetail is the DNSError.Err text used when a response
// contains records with names that fail isDomainName. LookupSRV, LookupMX,
// LookupNS and LookupAddr return the valid subset alongside this error, while
// LookupCNAME returns "".
var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names"
676
677
678
679
// dial opens a connection to a DNS server.
//
// It uses r.Dial when r is non-nil and has a Dial func, and otherwise falls
// back to a zero-value Dialer. Errors are passed through mapErr.
// NOTE(review): server is presumably already an IP literal. Dialing a hostname
// here could recurse back into name resolution — confirm callers guarantee
// this.
func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
	var (
		c   Conn
		err error
	)
	if r == nil || r.Dial == nil {
		var d Dialer
		c, err = d.DialContext(ctx, network, server)
	} else {
		c, err = r.Dial(ctx, network, server)
	}
	if err != nil {
		return nil, mapErr(err)
	}
	return c, nil
}
699
700
701
702
703
704
705
706
707
708
// goLookupSRV resolves SRV records with the pure-Go resolver.
//
// The query name is "_service._proto.name", or just name when both service and
// proto are empty. It returns the owner name of the first SRV answer that has
// one, plus the records sorted by byPriorityWeight. Answers of other types are
// skipped. Any parse failure yields a *DNSError "cannot unmarshal DNS
// message" naming the responding server.
func (r *Resolver) goLookupSRV(ctx context.Context, service, proto, name string) (target string, srvs []*SRV, err error) {
	target = name
	if service != "" || proto != "" {
		target = "_" + service + "._" + proto + "." + name
	}
	p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV, nil)
	if err != nil {
		return "", nil, err
	}

	// badMessage builds the error reported for any parse failure below.
	badMessage := func() error {
		return &DNSError{
			Err:    "cannot unmarshal DNS message",
			Name:   name,
			Server: server,
		}
	}

	var cname dnsmessage.Name
	for {
		h, err := p.AnswerHeader()
		switch {
		case err == dnsmessage.ErrSectionDone:
			// All answers consumed.
			byPriorityWeight(srvs).sort()
			return cname.String(), srvs, nil
		case err != nil:
			return "", nil, badMessage()
		case h.Type != dnsmessage.TypeSRV:
			if p.SkipAnswer() != nil {
				return "", nil, badMessage()
			}
			continue
		}
		if cname.Length == 0 && h.Name.Length != 0 {
			cname = h.Name
		}
		rec, err := p.SRVResource()
		if err != nil {
			return "", nil, badMessage()
		}
		srvs = append(srvs, &SRV{
			Target:   rec.Target.String(),
			Port:     rec.Port,
			Priority: rec.Priority,
			Weight:   rec.Weight,
		})
	}
}
758
759
// goLookupMX resolves MX records for name with the pure-Go resolver and returns
// them sorted by byPref.
//
// Answers of other types are skipped. Any parse failure yields a *DNSError
// "cannot unmarshal DNS message" naming the responding server.
func (r *Resolver) goLookupMX(ctx context.Context, name string) ([]*MX, error) {
	p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX, nil)
	if err != nil {
		return nil, err
	}

	// badMessage builds the error reported for any parse failure below.
	badMessage := func() error {
		return &DNSError{
			Err:    "cannot unmarshal DNS message",
			Name:   name,
			Server: server,
		}
	}

	var mxs []*MX
	for {
		h, err := p.AnswerHeader()
		switch {
		case err == dnsmessage.ErrSectionDone:
			// All answers consumed.
			byPref(mxs).sort()
			return mxs, nil
		case err != nil:
			return nil, badMessage()
		case h.Type != dnsmessage.TypeMX:
			if p.SkipAnswer() != nil {
				return nil, badMessage()
			}
			continue
		}
		rec, err := p.MXResource()
		if err != nil {
			return nil, badMessage()
		}
		mxs = append(mxs, &MX{Host: rec.MX.String(), Pref: rec.Pref})
	}
}
802
803
// goLookupNS resolves NS records for name with the pure-Go resolver, returning
// them in answer order.
//
// Answers of other types are skipped. Any parse failure yields a *DNSError
// "cannot unmarshal DNS message" naming the responding server.
func (r *Resolver) goLookupNS(ctx context.Context, name string) ([]*NS, error) {
	p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS, nil)
	if err != nil {
		return nil, err
	}

	// badMessage builds the error reported for any parse failure below.
	badMessage := func() error {
		return &DNSError{
			Err:    "cannot unmarshal DNS message",
			Name:   name,
			Server: server,
		}
	}

	var nss []*NS
	for {
		h, err := p.AnswerHeader()
		switch {
		case err == dnsmessage.ErrSectionDone:
			// All answers consumed.
			return nss, nil
		case err != nil:
			return nil, badMessage()
		case h.Type != dnsmessage.TypeNS:
			if p.SkipAnswer() != nil {
				return nil, badMessage()
			}
			continue
		}
		rec, err := p.NSResource()
		if err != nil {
			return nil, badMessage()
		}
		nss = append(nss, &NS{Host: rec.NS.String()})
	}
}
844
845
// goLookupTXT resolves TXT records for name with the pure-Go resolver.
//
// Each TXT record's strings are concatenated with no separator into a single
// result string, one per record, in answer order. Answers of other types are
// skipped. Any parse failure yields a *DNSError "cannot unmarshal DNS message"
// naming the responding server.
func (r *Resolver) goLookupTXT(ctx context.Context, name string) ([]string, error) {
	p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT, nil)
	if err != nil {
		return nil, err
	}

	// badMessage builds the error reported for any parse failure below.
	badMessage := func() error {
		return &DNSError{
			Err:    "cannot unmarshal DNS message",
			Name:   name,
			Server: server,
		}
	}

	var txts []string
	for {
		h, err := p.AnswerHeader()
		switch {
		case err == dnsmessage.ErrSectionDone:
			// All answers consumed.
			return txts, nil
		case err != nil:
			return nil, badMessage()
		case h.Type != dnsmessage.TypeTXT:
			if p.SkipAnswer() != nil {
				return nil, badMessage()
			}
			continue
		}
		rec, err := p.TXTResource()
		if err != nil {
			return nil, badMessage()
		}
		// Size the buffer exactly, then join the record's strings in one pass.
		size := 0
		for _, seg := range rec.TXT {
			size += len(seg)
		}
		joined := make([]byte, 0, size)
		for _, seg := range rec.TXT {
			joined = append(joined, seg...)
		}
		txts = append(txts, string(joined))
	}
}
900
// parseCNAMEFromResources returns the target of the first resource, which
// must be a CNAME record.
//
// An empty slice and a first resource whose body is not a *CNAMEResource each
// produce a distinct error.
func parseCNAMEFromResources(resources []dnsmessage.Resource) (string, error) {
	if len(resources) == 0 {
		return "", errors.New("no CNAME record received")
	}
	if c, ok := resources[0].Body.(*dnsmessage.CNAMEResource); ok {
		return c.CNAME.String(), nil
	}
	return "", errors.New("could not parse CNAME record")
}
911
View as plain text