Source file
src/net/dnsclient_unix.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package net
16
17 import (
18 "context"
19 "errors"
20 "internal/itoa"
21 "io"
22 "os"
23 "runtime"
24 "sync"
25 "sync/atomic"
26 "time"
27
28 "golang.org/x/net/dns/dnsmessage"
29 )
30
const (
	// useTCPOnly and useUDPOrTCP name the two values of the boolean useTCP
	// argument accepted by exchange: query over TCP only, or query over UDP
	// and fall back to TCP when the UDP response is truncated.
	useTCPOnly  = true
	useUDPOrTCP = false

	// maxDNSPacketSize is the UDP payload size, in bytes, advertised in the
	// EDNS(0) OPT record of every query built by newRequest. It is also the
	// size of the receive buffer used by dnsPacketRoundTrip.
	// 1232 is the value recommended by DNS Flag Day 2020.
	maxDNSPacketSize = 1232
)
40
41 var (
42 errLameReferral = errors.New("lame referral")
43 errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
44 errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")
45 errServerMisbehaving = errors.New("server misbehaving")
46 errInvalidDNSResponse = errors.New("invalid DNS response")
47 errNoAnswerFromDNSServer = errors.New("no answer from DNS server")
48
49
50
51
52 errServerTemporarilyMisbehaving = errors.New("server misbehaving")
53 )
54
55 func newRequest(q dnsmessage.Question, ad bool) (id uint16, udpReq, tcpReq []byte, err error) {
56 id = uint16(randInt())
57 b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true, AuthenticData: ad})
58 if err := b.StartQuestions(); err != nil {
59 return 0, nil, nil, err
60 }
61 if err := b.Question(q); err != nil {
62 return 0, nil, nil, err
63 }
64
65
66 if err := b.StartAdditionals(); err != nil {
67 return 0, nil, nil, err
68 }
69 var rh dnsmessage.ResourceHeader
70 if err := rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false); err != nil {
71 return 0, nil, nil, err
72 }
73 if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil {
74 return 0, nil, nil, err
75 }
76
77 tcpReq, err = b.Finish()
78 if err != nil {
79 return 0, nil, nil, err
80 }
81 udpReq = tcpReq[2:]
82 l := len(tcpReq) - 2
83 tcpReq[0] = byte(l >> 8)
84 tcpReq[1] = byte(l)
85 return id, udpReq, tcpReq, nil
86 }
87
88 func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
89 if !respHdr.Response {
90 return false
91 }
92 if reqID != respHdr.ID {
93 return false
94 }
95 if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
96 return false
97 }
98 return true
99 }
100
101 func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
102 if _, err := c.Write(b); err != nil {
103 return dnsmessage.Parser{}, dnsmessage.Header{}, err
104 }
105
106 b = make([]byte, maxDNSPacketSize)
107 for {
108 n, err := c.Read(b)
109 if err != nil {
110 return dnsmessage.Parser{}, dnsmessage.Header{}, err
111 }
112 var p dnsmessage.Parser
113
114
115
116 h, err := p.Start(b[:n])
117 if err != nil {
118 continue
119 }
120 q, err := p.Question()
121 if err != nil || !checkResponse(id, query, h, q) {
122 continue
123 }
124 return p, h, nil
125 }
126 }
127
// dnsStreamRoundTrip writes the length-prefixed query b to the stream
// connection c and reads exactly one length-prefixed response. It then
// validates that response against id and query.
//
// Unlike dnsPacketRoundTrip, a malformed response yields
// errCannotUnmarshalDNSMessage, and a mismatched one yields
// errInvalidDNSResponse; neither is skipped. On success the returned parser
// is positioned just after the first question.
func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
	if _, err := c.Write(b); err != nil {
		return dnsmessage.Parser{}, dnsmessage.Header{}, err
	}

	// 1280 bytes covers typical responses; larger ones get a buffer of
	// exactly the announced length.
	buf := make([]byte, 1280)
	if _, err := io.ReadFull(c, buf[:2]); err != nil {
		return dnsmessage.Parser{}, dnsmessage.Header{}, err
	}
	msgLen := int(buf[0])<<8 | int(buf[1])
	if msgLen > len(buf) {
		buf = make([]byte, msgLen)
	}
	n, err := io.ReadFull(c, buf[:msgLen])
	if err != nil {
		return dnsmessage.Parser{}, dnsmessage.Header{}, err
	}

	var p dnsmessage.Parser
	var q dnsmessage.Question
	h, err := p.Start(buf[:n])
	if err == nil {
		q, err = p.Question()
	}
	if err != nil {
		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
	}
	if !checkResponse(id, query, h, q) {
		return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
	}
	return p, h, nil
}
159
160
// exchange sends q (forced to class INET) to server. It returns the parsed
// response, with the parser positioned at the start of the answer section.
// ad requests the AD bit in the query.
//
// Unless useTCP is set, the query is first sent over "udp" and is retried
// over "tcp" when the response comes back truncated. Each attempt gets its
// own deadline of timeout from now, derived from ctx. Round-trip errors are
// passed through mapErr.
//
// errNoAnswerFromDNSServer is returned only if the last attempted network
// also produced a truncated response.
func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP, ad bool) (dnsmessage.Parser, dnsmessage.Header, error) {
	q.Class = dnsmessage.ClassINET
	id, udpReq, tcpReq, err := newRequest(q, ad)
	if err != nil {
		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
	}
	networks := []string{"udp", "tcp"}
	if useTCP {
		networks = []string{"tcp"}
	}
	for _, network := range networks {
		// Release each attempt's context as soon as that attempt is over.
		// A deferred cancel would run only at function return, keeping the
		// UDP attempt's timer alive during the TCP retry.
		attemptCtx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))

		c, err := r.dial(attemptCtx, network, server)
		if err != nil {
			cancel()
			return dnsmessage.Parser{}, dnsmessage.Header{}, err
		}
		if d, ok := attemptCtx.Deadline(); ok && !d.IsZero() {
			// Best effort: bound the reads and writes below by the
			// attempt's deadline.
			c.SetDeadline(d)
		}
		var p dnsmessage.Parser
		var h dnsmessage.Header
		if _, ok := c.(PacketConn); ok {
			p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
		} else {
			p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
		}
		c.Close()
		cancel()
		if err != nil {
			return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
		}
		if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
			return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
		}
		if h.Truncated {
			// The complete answer did not fit; try the next network (TCP).
			continue
		}
		return p, h, nil
	}
	return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
}
205
206
207 func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
208 if h.RCode == dnsmessage.RCodeNameError {
209 return errNoSuchHost
210 }
211
212 _, err := p.AnswerHeader()
213 if err != nil && err != dnsmessage.ErrSectionDone {
214 return errCannotUnmarshalDNSMessage
215 }
216
217
218
219 if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
220 return errLameReferral
221 }
222
223 if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
224
225
226
227
228
229 if h.RCode == dnsmessage.RCodeServerFailure {
230 return errServerTemporarilyMisbehaving
231 }
232 return errServerMisbehaving
233 }
234
235 return nil
236 }
237
238 func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
239 for {
240 h, err := p.AnswerHeader()
241 if err == dnsmessage.ErrSectionDone {
242 return errNoSuchHost
243 }
244 if err != nil {
245 return errCannotUnmarshalDNSMessage
246 }
247 if h.Type == qtype {
248 return nil
249 }
250 if err := p.SkipAnswer(); err != nil {
251 return errCannotUnmarshalDNSMessage
252 }
253 }
254 }
255
256
257
// tryOneName queries the configured servers for records of type qtype for
// the single fully qualified name. Servers are tried round-robin starting
// at cfg.serverOffset(), for cfg.attempts rounds.
//
// It returns the parser positioned at the first matching answer and the
// server that produced it. If a server reports that the name does not
// exist, or that it has no record of type qtype, the search stops
// immediately with an IsNotFound DNSError. Otherwise the last failure is
// returned as a DNSError. Network (*OpError) failures and SERVFAIL
// responses are marked IsTemporary, and timeouts are marked IsTimeout.
func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
	var lastErr error
	serverOffset := cfg.serverOffset()
	sLen := uint32(len(cfg.servers))

	qname, err := dnsmessage.NewName(name)
	if err != nil {
		return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
	}
	q := dnsmessage.Question{Name: qname, Type: qtype, Class: dnsmessage.ClassINET}

	// newErr wraps err as a DNSError for name as answered by server.
	newErr := func(err error, server string) *DNSError {
		return &DNSError{Err: err.Error(), Name: name, Server: server}
	}

	for round := 0; round < cfg.attempts; round++ {
		for k := uint32(0); k < sLen; k++ {
			server := cfg.servers[(serverOffset+k)%sLen]

			p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP, cfg.trustAD)
			if err != nil {
				dnsErr := newErr(err, server)
				if nerr, ok := err.(Error); ok && nerr.Timeout() {
					dnsErr.IsTimeout = true
				}
				// Socket-level errors are treated as temporary.
				if _, ok := err.(*OpError); ok {
					dnsErr.IsTemporary = true
				}
				lastErr = dnsErr
				continue
			}

			if err := checkHeader(&p, h); err != nil {
				dnsErr := newErr(err, server)
				dnsErr.IsTemporary = err == errServerTemporarilyMisbehaving
				if err == errNoSuchHost {
					// The name does not exist; another server won't help.
					dnsErr.IsNotFound = true
					return p, server, dnsErr
				}
				lastErr = dnsErr
				continue
			}

			switch err := skipToAnswer(&p, qtype); err {
			case nil:
				return p, server, nil
			case errNoSuchHost:
				// No record of this type exists; another server won't help.
				dnsErr := newErr(err, server)
				dnsErr.IsNotFound = true
				return p, server, dnsErr
			default:
				lastErr = newErr(err, server)
			}
		}
	}
	return dnsmessage.Parser{}, "", lastErr
}
336
337
338 type resolverConfig struct {
339 initOnce sync.Once
340
341
342
343 ch chan struct{}
344 lastChecked time.Time
345
346 dnsConfig atomic.Pointer[dnsConfig]
347 }
348
349 var resolvConf resolverConfig
350
351 func getSystemDNSConfig() *dnsConfig {
352 resolvConf.tryUpdate("/etc/resolv.conf")
353 return resolvConf.dnsConfig.Load()
354 }
355
356
// init creates the update semaphore and performs the first read of
// /etc/resolv.conf. It records lastChecked so that the tryUpdate call that
// triggered init (via initOnce) does not immediately read the file again.
func (conf *resolverConfig) init() {
	conf.ch = make(chan struct{}, 1) // one-slot semaphore used by tryUpdate
	conf.dnsConfig.Store(dnsReadConfig("/etc/resolv.conf"))
	conf.lastChecked = time.Now()
}
367
368
369
370
// tryUpdate re-reads the resolver configuration from name if it may have
// changed.
//
// A configuration with noReload set is never refreshed. Otherwise:
//   - at most one goroutine checks at a time, and the others return at once;
//   - checks happen at most once every five seconds;
//   - except on Windows, the file is re-read only when its modification time
//     differs from the one recorded in the cached configuration.
func (conf *resolverConfig) tryUpdate(name string) {
	conf.initOnce.Do(conf.init)

	if conf.dnsConfig.Load().noReload {
		return
	}

	if !conf.tryAcquireSema() {
		return // another goroutine is already checking
	}
	defer conf.releaseSema()

	const recheckInterval = 5 * time.Second
	now := time.Now()
	if conf.lastChecked.After(now.Add(-recheckInterval)) {
		return
	}
	conf.lastChecked = now

	// On Windows the mtime check is skipped and the configuration is always
	// re-read.
	if runtime.GOOS != "windows" {
		var mtime time.Time // zero if the file cannot be stat'ed
		if fi, err := os.Stat(name); err == nil {
			mtime = fi.ModTime()
		}
		if mtime.Equal(conf.dnsConfig.Load().mtime) {
			return
		}
	}

	conf.dnsConfig.Store(dnsReadConfig(name))
}
410
411 func (conf *resolverConfig) tryAcquireSema() bool {
412 select {
413 case conf.ch <- struct{}{}:
414 return true
415 default:
416 return false
417 }
418 }
419
420 func (conf *resolverConfig) releaseSema() {
421 <-conf.ch
422 }
423
// lookup resolves records of type qtype for name. It tries each candidate
// produced by conf.nameList in turn; a nil conf means the system
// configuration. It returns the parser positioned at the first matching
// answer and the server that produced it.
//
// With StrictErrors, a temporary failure stops the search. Any DNSError
// returned reports the caller's name rather than a search-list expansion.
func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type, conf *dnsConfig) (dnsmessage.Parser, string, error) {
	if !isDomainName(name) {
		// An invalid name is reported as "no such host", the same as a
		// well-formed name that does not exist.
		return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
	}
	if conf == nil {
		conf = getSystemDNSConfig()
	}

	var err error
	for _, fqdn := range conf.nameList(name) {
		var (
			p      dnsmessage.Parser
			server string
		)
		if p, server, err = r.tryOneName(ctx, conf, fqdn, qtype); err == nil {
			return p, server, nil
		}
		// With StrictErrors, a temporary failure aborts the search instead
		// of moving on to the next candidate name.
		if nerr, ok := err.(Error); ok && nerr.Temporary() && r.strictErrors() {
			break
		}
	}
	// Report the name the caller asked for, not whichever search-list
	// expansion happened to fail last.
	if dnsErr, ok := err.(*DNSError); ok {
		dnsErr.Name = name
	}
	return dnsmessage.Parser{}, "", err
}
465
466
467
468
469
470 func avoidDNS(name string) bool {
471 if name == "" {
472 return true
473 }
474 if name[len(name)-1] == '.' {
475 name = name[:len(name)-1]
476 }
477 return stringsHasSuffixFold(name, ".onion")
478 }
479
480
481 func (conf *dnsConfig) nameList(name string) []string {
482 if avoidDNS(name) {
483 return nil
484 }
485
486
487 l := len(name)
488 rooted := l > 0 && name[l-1] == '.'
489 if l > 254 || l == 254 && !rooted {
490 return nil
491 }
492
493
494 if rooted {
495 return []string{name}
496 }
497
498 hasNdots := count(name, '.') >= conf.ndots
499 name += "."
500 l++
501
502
503 names := make([]string, 0, 1+len(conf.search))
504
505 if hasNdots {
506 names = append(names, name)
507 }
508
509 for _, suffix := range conf.search {
510 if l+len(suffix) <= 254 {
511 names = append(names, name+suffix)
512 }
513 }
514
515 if !hasNdots {
516 names = append(names, name)
517 }
518 return names
519 }
520
521
522
523
524 type hostLookupOrder int
525
526 const (
527
528 hostLookupCgo hostLookupOrder = iota
529 hostLookupFilesDNS
530 hostLookupDNSFiles
531 hostLookupFiles
532 hostLookupDNS
533 )
534
535 var lookupOrderName = map[hostLookupOrder]string{
536 hostLookupCgo: "cgo",
537 hostLookupFilesDNS: "files,dns",
538 hostLookupDNSFiles: "dns,files",
539 hostLookupFiles: "files",
540 hostLookupDNS: "dns",
541 }
542
543 func (o hostLookupOrder) String() string {
544 if s, ok := lookupOrderName[o]; ok {
545 return s
546 }
547 return "hostLookupOrder=" + itoa.Itoa(int(o)) + "??"
548 }
549
// goLookupHostOrder returns the addresses of name, formatted as strings,
// consulting the static hosts table (lookupStaticHost) and DNS as directed
// by order. For the files-first orders, a static match is returned directly.
// hostLookupFiles never falls back to DNS. A nil conf means the system
// configuration.
func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder, conf *dnsConfig) (addrs []string, err error) {
	switch order {
	case hostLookupFilesDNS, hostLookupFiles:
		addrs, _ = lookupStaticHost(name)
		if len(addrs) > 0 {
			return addrs, nil
		}
		if order == hostLookupFiles {
			return nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
		}
	}

	ips, _, err := r.goLookupIPCNAMEOrder(ctx, "ip", name, order, conf)
	if err != nil {
		return addrs, err
	}
	addrs = make([]string, len(ips))
	for i, ip := range ips {
		addrs[i] = ip.String()
	}
	return addrs, nil
}
572
573
574 func goLookupIPFiles(name string) (addrs []IPAddr, canonical string) {
575 addr, canonical := lookupStaticHost(name)
576 for _, haddr := range addr {
577 haddr, zone := splitHostZone(haddr)
578 if ip := ParseIP(haddr); ip != nil {
579 addr := IPAddr{IP: ip, Zone: zone}
580 addrs = append(addrs, addr)
581 }
582 }
583 sortByRFC6724(addrs)
584 return addrs, canonical
585 }
586
587
588
// goLookupIP returns the IP addresses of host for the given network. The
// lookup order and configuration come from systemConf().hostLookupOrder;
// the canonical name is discarded.
func (r *Resolver) goLookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
	order, conf := systemConf().hostLookupOrder(r, host)
	addrs, _, err = r.goLookupIPCNAMEOrder(ctx, network, host, order, conf)
	return addrs, err
}
594
// goLookupIPCNAMEOrder resolves name to IP addresses, sorted with
// sortByRFC6724, and reports a canonical name found along the way. It
// consults the static hosts table and DNS as directed by order. A nil conf
// means the system configuration.
//
// network selects the DNS record types, as classified by ipVersion: A only
// for an IPv4 network, AAAA only for an IPv6 network, otherwise both. The
// special network "CNAME" also queries CNAME records and treats a found
// canonical name as success even without addresses.
//
// Unparsable answer sections are reported as DNSErrors carrying
// errCannotUnmarshalDNSMessage. Any DNSError returned names the caller's
// name rather than a search-list expansion.
func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name string, order hostLookupOrder, conf *dnsConfig) (addrs []IPAddr, cname dnsmessage.Name, err error) {
	if order == hostLookupFilesDNS || order == hostLookupFiles {
		var canonical string
		addrs, canonical = goLookupIPFiles(name)

		if len(addrs) > 0 {
			var err error
			cname, err = dnsmessage.NewName(canonical)
			if err != nil {
				return nil, dnsmessage.Name{}, err
			}
			return addrs, cname, nil
		}

		if order == hostLookupFiles {
			return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
		}
	}

	if !isDomainName(name) {
		// Report an invalid name the same way as a name that does not exist.
		return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
	}
	type result struct {
		p      dnsmessage.Parser
		server string
		error
	}

	if conf == nil {
		conf = getSystemDNSConfig()
	}

	// badAnswer is the error recorded when a response's answer section
	// cannot be parsed. Parsing a response is unmarshaling, so report
	// errCannotUnmarshalDNSMessage.
	badAnswer := func(server string) error {
		return &DNSError{
			Err:    errCannotUnmarshalDNSMessage.Error(),
			Name:   name,
			Server: server,
		}
	}

	lane := make(chan result, 1)
	qtypes := []dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
	if network == "CNAME" {
		qtypes = append(qtypes, dnsmessage.TypeCNAME)
	}
	switch ipVersion(network) {
	case '4':
		qtypes = []dnsmessage.Type{dnsmessage.TypeA}
	case '6':
		qtypes = []dnsmessage.Type{dnsmessage.TypeAAAA}
	}

	// queryFn starts the query for one (fqdn, qtype) pair, and responseFn
	// collects one result.
	//   - With singleRequest, each query runs synchronously inside
	//     responseFn.
	//   - Otherwise all queries for an fqdn run concurrently, and responseFn
	//     receives results from lane in completion order. Its qtype argument
	//     is then unused.
	// Either way, responseFn is called once per qtype, so every goroutine's
	// send on lane is received.
	var queryFn func(fqdn string, qtype dnsmessage.Type)
	var responseFn func(fqdn string, qtype dnsmessage.Type) result
	if conf.singleRequest {
		queryFn = func(fqdn string, qtype dnsmessage.Type) {}
		responseFn = func(fqdn string, qtype dnsmessage.Type) result {
			dnsWaitGroup.Add(1)
			defer dnsWaitGroup.Done()
			p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
			return result{p, server, err}
		}
	} else {
		queryFn = func(fqdn string, qtype dnsmessage.Type) {
			dnsWaitGroup.Add(1)
			go func(qtype dnsmessage.Type) {
				p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
				lane <- result{p, server, err}
				dnsWaitGroup.Done()
			}(qtype)
		}
		responseFn = func(fqdn string, qtype dnsmessage.Type) result {
			return <-lane
		}
	}
	var lastErr error
	for _, fqdn := range conf.nameList(name) {
		for _, qtype := range qtypes {
			queryFn(fqdn, qtype)
		}
		hitStrictError := false
		for _, qtype := range qtypes {
			result := responseFn(fqdn, qtype)
			if result.error != nil {
				if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
					// This error will abort the nameList loop.
					hitStrictError = true
					lastErr = result.error
				} else if lastErr == nil || fqdn == name+"." {
					// Prefer the error for the name exactly as given.
					lastErr = result.error
				}
				continue
			}

			// The query asked for recursion, so treat the response as
			// complete. CNAME records are not chased; A and AAAA records in
			// the answer are taken as belonging to the canonical name.
		loop:
			for {
				h, err := result.p.AnswerHeader()
				if err != nil && err != dnsmessage.ErrSectionDone {
					lastErr = badAnswer(result.server)
				}
				if err != nil {
					break
				}
				switch h.Type {
				case dnsmessage.TypeA:
					a, err := result.p.AResource()
					if err != nil {
						lastErr = badAnswer(result.server)
						break loop
					}
					addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
					if cname.Length == 0 && h.Name.Length != 0 {
						cname = h.Name
					}

				case dnsmessage.TypeAAAA:
					aaaa, err := result.p.AAAAResource()
					if err != nil {
						lastErr = badAnswer(result.server)
						break loop
					}
					addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
					if cname.Length == 0 && h.Name.Length != 0 {
						cname = h.Name
					}

				case dnsmessage.TypeCNAME:
					c, err := result.p.CNAMEResource()
					if err != nil {
						lastErr = badAnswer(result.server)
						break loop
					}
					if cname.Length == 0 && c.CNAME.Length > 0 {
						cname = c.CNAME
					}

				default:
					if err := result.p.SkipAnswer(); err != nil {
						lastErr = badAnswer(result.server)
						break loop
					}
					continue
				}
			}
		}
		if hitStrictError {
			// With StrictErrors, a temporary failure for any record type
			// discards all addresses. This keeps a flaky network from making
			// a dual-stack host look IPv4-only or IPv6-only.
			addrs = nil
			break
		}
		if len(addrs) > 0 || network == "CNAME" && cname.Length > 0 {
			break
		}
	}
	if lastErr, ok := lastErr.(*DNSError); ok {
		// Report the name the caller asked for, not a search-list expansion.
		lastErr.Name = name
	}
	sortByRFC6724(addrs)
	if len(addrs) == 0 && !(network == "CNAME" && cname.Length > 0) {
		if order == hostLookupDNSFiles {
			var canonical string
			addrs, canonical = goLookupIPFiles(name)
			if len(addrs) > 0 {
				var err error
				cname, err = dnsmessage.NewName(canonical)
				if err != nil {
					return nil, dnsmessage.Name{}, err
				}
				return addrs, cname, nil
			}
		}
		if lastErr != nil {
			return nil, dnsmessage.Name{}, lastErr
		}
	}
	return addrs, cname, nil
}
805
806
// goLookupCNAME returns the canonical name of host. It runs an address
// lookup with the special network "CNAME", which also queries CNAME
// records, and discards the addresses.
func (r *Resolver) goLookupCNAME(ctx context.Context, host string, order hostLookupOrder, conf *dnsConfig) (string, error) {
	_, canonical, err := r.goLookupIPCNAMEOrder(ctx, "CNAME", host, order, conf)
	return canonical.String(), err
}
811
812
813
814
815
816
// goLookupPTR returns the host names for the address addr.
//
// Entries from the static table (lookupStaticAddr) win if present.
// Otherwise the reverse name from reverseaddr is looked up as PTR records,
// and non-PTR answers are skipped. A malformed answer section is reported
// as a DNSError carrying errCannotUnmarshalDNSMessage.
func (r *Resolver) goLookupPTR(ctx context.Context, addr string, conf *dnsConfig) ([]string, error) {
	names := lookupStaticAddr(addr)
	if len(names) > 0 {
		return names, nil
	}
	arpa, err := reverseaddr(addr)
	if err != nil {
		return nil, err
	}
	p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR, conf)
	if err != nil {
		return nil, err
	}

	// malformed is the error for an answer section that cannot be parsed.
	// Parsing is unmarshaling, so use errCannotUnmarshalDNSMessage.
	malformed := func() error {
		return &DNSError{
			Err:    errCannotUnmarshalDNSMessage.Error(),
			Name:   addr,
			Server: server,
		}
	}

	var ptrs []string
	for {
		h, err := p.AnswerHeader()
		if err == dnsmessage.ErrSectionDone {
			break
		}
		if err != nil {
			return nil, malformed()
		}
		if h.Type != dnsmessage.TypePTR {
			if err := p.SkipAnswer(); err != nil {
				return nil, malformed()
			}
			continue
		}
		ptr, err := p.PTRResource()
		if err != nil {
			return nil, malformed()
		}
		ptrs = append(ptrs, ptr.PTR.String())
	}
	return ptrs, nil
}
867
View as plain text