Source file src/net/dnsclient_unix.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build !js
     6  
     7  // DNS client: see RFC 1035.
     8  // Has to be linked into package net for Dial.
     9  
    10  // TODO(rsc):
    11  //	Could potentially handle many outstanding lookups faster.
    12  //	Random UDP source port (net.Dial should do that for us).
    13  //	Random request IDs.
    14  
    15  package net
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"internal/itoa"
    21  	"io"
    22  	"os"
    23  	"runtime"
    24  	"sync"
    25  	"sync/atomic"
    26  	"time"
    27  
    28  	"golang.org/x/net/dns/dnsmessage"
    29  )
    30  
    31  const (
    32  	// to be used as a useTCP parameter to exchange
    33  	useTCPOnly  = true
    34  	useUDPOrTCP = false
    35  
    36  	// Maximum DNS packet size.
    37  	// Value taken from https://dnsflagday.net/2020/.
    38  	maxDNSPacketSize = 1232
    39  )
    40  
    41  var (
    42  	errLameReferral              = errors.New("lame referral")
    43  	errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
    44  	errCannotMarshalDNSMessage   = errors.New("cannot marshal DNS message")
    45  	errServerMisbehaving         = errors.New("server misbehaving")
    46  	errInvalidDNSResponse        = errors.New("invalid DNS response")
    47  	errNoAnswerFromDNSServer     = errors.New("no answer from DNS server")
    48  
    49  	// errServerTemporarilyMisbehaving is like errServerMisbehaving, except
    50  	// that when it gets translated to a DNSError, the IsTemporary field
    51  	// gets set to true.
    52  	errServerTemporarilyMisbehaving = errors.New("server misbehaving")
    53  )
    54  
    55  func newRequest(q dnsmessage.Question, ad bool) (id uint16, udpReq, tcpReq []byte, err error) {
    56  	id = uint16(randInt())
    57  	b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true, AuthenticData: ad})
    58  	if err := b.StartQuestions(); err != nil {
    59  		return 0, nil, nil, err
    60  	}
    61  	if err := b.Question(q); err != nil {
    62  		return 0, nil, nil, err
    63  	}
    64  
    65  	// Accept packets up to maxDNSPacketSize.  RFC 6891.
    66  	if err := b.StartAdditionals(); err != nil {
    67  		return 0, nil, nil, err
    68  	}
    69  	var rh dnsmessage.ResourceHeader
    70  	if err := rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false); err != nil {
    71  		return 0, nil, nil, err
    72  	}
    73  	if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil {
    74  		return 0, nil, nil, err
    75  	}
    76  
    77  	tcpReq, err = b.Finish()
    78  	if err != nil {
    79  		return 0, nil, nil, err
    80  	}
    81  	udpReq = tcpReq[2:]
    82  	l := len(tcpReq) - 2
    83  	tcpReq[0] = byte(l >> 8)
    84  	tcpReq[1] = byte(l)
    85  	return id, udpReq, tcpReq, nil
    86  }
    87  
    88  func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
    89  	if !respHdr.Response {
    90  		return false
    91  	}
    92  	if reqID != respHdr.ID {
    93  		return false
    94  	}
    95  	if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
    96  		return false
    97  	}
    98  	return true
    99  }
   100  
   101  func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
   102  	if _, err := c.Write(b); err != nil {
   103  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   104  	}
   105  
   106  	b = make([]byte, maxDNSPacketSize)
   107  	for {
   108  		n, err := c.Read(b)
   109  		if err != nil {
   110  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
   111  		}
   112  		var p dnsmessage.Parser
   113  		// Ignore invalid responses as they may be malicious
   114  		// forgery attempts. Instead continue waiting until
   115  		// timeout. See golang.org/issue/13281.
   116  		h, err := p.Start(b[:n])
   117  		if err != nil {
   118  			continue
   119  		}
   120  		q, err := p.Question()
   121  		if err != nil || !checkResponse(id, query, h, q) {
   122  			continue
   123  		}
   124  		return p, h, nil
   125  	}
   126  }
   127  
   128  func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
   129  	if _, err := c.Write(b); err != nil {
   130  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   131  	}
   132  
   133  	b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
   134  	if _, err := io.ReadFull(c, b[:2]); err != nil {
   135  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   136  	}
   137  	l := int(b[0])<<8 | int(b[1])
   138  	if l > len(b) {
   139  		b = make([]byte, l)
   140  	}
   141  	n, err := io.ReadFull(c, b[:l])
   142  	if err != nil {
   143  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   144  	}
   145  	var p dnsmessage.Parser
   146  	h, err := p.Start(b[:n])
   147  	if err != nil {
   148  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
   149  	}
   150  	q, err := p.Question()
   151  	if err != nil {
   152  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
   153  	}
   154  	if !checkResponse(id, query, h, q) {
   155  		return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
   156  	}
   157  	return p, h, nil
   158  }
   159  
   160  // exchange sends a query on the connection and hopes for a response.
   161  func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP, ad bool) (dnsmessage.Parser, dnsmessage.Header, error) {
   162  	q.Class = dnsmessage.ClassINET
   163  	id, udpReq, tcpReq, err := newRequest(q, ad)
   164  	if err != nil {
   165  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
   166  	}
   167  	var networks []string
   168  	if useTCP {
   169  		networks = []string{"tcp"}
   170  	} else {
   171  		networks = []string{"udp", "tcp"}
   172  	}
   173  	for _, network := range networks {
   174  		ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
   175  		defer cancel()
   176  
   177  		c, err := r.dial(ctx, network, server)
   178  		if err != nil {
   179  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
   180  		}
   181  		if d, ok := ctx.Deadline(); ok && !d.IsZero() {
   182  			c.SetDeadline(d)
   183  		}
   184  		var p dnsmessage.Parser
   185  		var h dnsmessage.Header
   186  		if _, ok := c.(PacketConn); ok {
   187  			p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
   188  		} else {
   189  			p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
   190  		}
   191  		c.Close()
   192  		if err != nil {
   193  			return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
   194  		}
   195  		if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
   196  			return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
   197  		}
   198  		if h.Truncated { // see RFC 5966
   199  			continue
   200  		}
   201  		return p, h, nil
   202  	}
   203  	return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
   204  }
   205  
   206  // checkHeader performs basic sanity checks on the header.
   207  func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
   208  	if h.RCode == dnsmessage.RCodeNameError {
   209  		return errNoSuchHost
   210  	}
   211  
   212  	_, err := p.AnswerHeader()
   213  	if err != nil && err != dnsmessage.ErrSectionDone {
   214  		return errCannotUnmarshalDNSMessage
   215  	}
   216  
   217  	// libresolv continues to the next server when it receives
   218  	// an invalid referral response. See golang.org/issue/15434.
   219  	if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
   220  		return errLameReferral
   221  	}
   222  
   223  	if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
   224  		// None of the error codes make sense
   225  		// for the query we sent. If we didn't get
   226  		// a name error and we didn't get success,
   227  		// the server is behaving incorrectly or
   228  		// having temporary trouble.
   229  		if h.RCode == dnsmessage.RCodeServerFailure {
   230  			return errServerTemporarilyMisbehaving
   231  		}
   232  		return errServerMisbehaving
   233  	}
   234  
   235  	return nil
   236  }
   237  
   238  func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
   239  	for {
   240  		h, err := p.AnswerHeader()
   241  		if err == dnsmessage.ErrSectionDone {
   242  			return errNoSuchHost
   243  		}
   244  		if err != nil {
   245  			return errCannotUnmarshalDNSMessage
   246  		}
   247  		if h.Type == qtype {
   248  			return nil
   249  		}
   250  		if err := p.SkipAnswer(); err != nil {
   251  			return errCannotUnmarshalDNSMessage
   252  		}
   253  	}
   254  }
   255  
   256  // Do a lookup for a single name, which must be rooted
   257  // (otherwise answer will not find the answers).
   258  func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
   259  	var lastErr error
   260  	serverOffset := cfg.serverOffset()
   261  	sLen := uint32(len(cfg.servers))
   262  
   263  	n, err := dnsmessage.NewName(name)
   264  	if err != nil {
   265  		return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
   266  	}
   267  	q := dnsmessage.Question{
   268  		Name:  n,
   269  		Type:  qtype,
   270  		Class: dnsmessage.ClassINET,
   271  	}
   272  
   273  	for i := 0; i < cfg.attempts; i++ {
   274  		for j := uint32(0); j < sLen; j++ {
   275  			server := cfg.servers[(serverOffset+j)%sLen]
   276  
   277  			p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP, cfg.trustAD)
   278  			if err != nil {
   279  				dnsErr := &DNSError{
   280  					Err:    err.Error(),
   281  					Name:   name,
   282  					Server: server,
   283  				}
   284  				if nerr, ok := err.(Error); ok && nerr.Timeout() {
   285  					dnsErr.IsTimeout = true
   286  				}
   287  				// Set IsTemporary for socket-level errors. Note that this flag
   288  				// may also be used to indicate a SERVFAIL response.
   289  				if _, ok := err.(*OpError); ok {
   290  					dnsErr.IsTemporary = true
   291  				}
   292  				lastErr = dnsErr
   293  				continue
   294  			}
   295  
   296  			if err := checkHeader(&p, h); err != nil {
   297  				dnsErr := &DNSError{
   298  					Err:    err.Error(),
   299  					Name:   name,
   300  					Server: server,
   301  				}
   302  				if err == errServerTemporarilyMisbehaving {
   303  					dnsErr.IsTemporary = true
   304  				}
   305  				if err == errNoSuchHost {
   306  					// The name does not exist, so trying
   307  					// another server won't help.
   308  
   309  					dnsErr.IsNotFound = true
   310  					return p, server, dnsErr
   311  				}
   312  				lastErr = dnsErr
   313  				continue
   314  			}
   315  
   316  			err = skipToAnswer(&p, qtype)
   317  			if err == nil {
   318  				return p, server, nil
   319  			}
   320  			lastErr = &DNSError{
   321  				Err:    err.Error(),
   322  				Name:   name,
   323  				Server: server,
   324  			}
   325  			if err == errNoSuchHost {
   326  				// The name does not exist, so trying another
   327  				// server won't help.
   328  
   329  				lastErr.(*DNSError).IsNotFound = true
   330  				return p, server, lastErr
   331  			}
   332  		}
   333  	}
   334  	return dnsmessage.Parser{}, "", lastErr
   335  }
   336  
   337  // A resolverConfig represents a DNS stub resolver configuration.
   338  type resolverConfig struct {
   339  	initOnce sync.Once // guards init of resolverConfig
   340  
   341  	// ch is used as a semaphore that only allows one lookup at a
   342  	// time to recheck resolv.conf.
   343  	ch          chan struct{} // guards lastChecked and modTime
   344  	lastChecked time.Time     // last time resolv.conf was checked
   345  
   346  	dnsConfig atomic.Pointer[dnsConfig] // parsed resolv.conf structure used in lookups
   347  }
   348  
   349  var resolvConf resolverConfig
   350  
   351  func getSystemDNSConfig() *dnsConfig {
   352  	resolvConf.tryUpdate("/etc/resolv.conf")
   353  	return resolvConf.dnsConfig.Load()
   354  }
   355  
   356  // init initializes conf and is only called via conf.initOnce.
   357  func (conf *resolverConfig) init() {
   358  	// Set dnsConfig and lastChecked so we don't parse
   359  	// resolv.conf twice the first time.
   360  	conf.dnsConfig.Store(dnsReadConfig("/etc/resolv.conf"))
   361  	conf.lastChecked = time.Now()
   362  
   363  	// Prepare ch so that only one update of resolverConfig may
   364  	// run at once.
   365  	conf.ch = make(chan struct{}, 1)
   366  }
   367  
   368  // tryUpdate tries to update conf with the named resolv.conf file.
   369  // The name variable only exists for testing. It is otherwise always
   370  // "/etc/resolv.conf".
   371  func (conf *resolverConfig) tryUpdate(name string) {
   372  	conf.initOnce.Do(conf.init)
   373  
   374  	if conf.dnsConfig.Load().noReload {
   375  		return
   376  	}
   377  
   378  	// Ensure only one update at a time checks resolv.conf.
   379  	if !conf.tryAcquireSema() {
   380  		return
   381  	}
   382  	defer conf.releaseSema()
   383  
   384  	now := time.Now()
   385  	if conf.lastChecked.After(now.Add(-5 * time.Second)) {
   386  		return
   387  	}
   388  	conf.lastChecked = now
   389  
   390  	switch runtime.GOOS {
   391  	case "windows":
   392  		// There's no file on disk, so don't bother checking
   393  		// and failing.
   394  		//
   395  		// The Windows implementation of dnsReadConfig (called
   396  		// below) ignores the name.
   397  	default:
   398  		var mtime time.Time
   399  		if fi, err := os.Stat(name); err == nil {
   400  			mtime = fi.ModTime()
   401  		}
   402  		if mtime.Equal(conf.dnsConfig.Load().mtime) {
   403  			return
   404  		}
   405  	}
   406  
   407  	dnsConf := dnsReadConfig(name)
   408  	conf.dnsConfig.Store(dnsConf)
   409  }
   410  
   411  func (conf *resolverConfig) tryAcquireSema() bool {
   412  	select {
   413  	case conf.ch <- struct{}{}:
   414  		return true
   415  	default:
   416  		return false
   417  	}
   418  }
   419  
   420  func (conf *resolverConfig) releaseSema() {
   421  	<-conf.ch
   422  }
   423  
   424  func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type, conf *dnsConfig) (dnsmessage.Parser, string, error) {
   425  	if !isDomainName(name) {
   426  		// We used to use "invalid domain name" as the error,
   427  		// but that is a detail of the specific lookup mechanism.
   428  		// Other lookups might allow broader name syntax
   429  		// (for example Multicast DNS allows UTF-8; see RFC 6762).
   430  		// For consistency with libc resolvers, report no such host.
   431  		return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
   432  	}
   433  
   434  	if conf == nil {
   435  		conf = getSystemDNSConfig()
   436  	}
   437  
   438  	var (
   439  		p      dnsmessage.Parser
   440  		server string
   441  		err    error
   442  	)
   443  	for _, fqdn := range conf.nameList(name) {
   444  		p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
   445  		if err == nil {
   446  			break
   447  		}
   448  		if nerr, ok := err.(Error); ok && nerr.Temporary() && r.strictErrors() {
   449  			// If we hit a temporary error with StrictErrors enabled,
   450  			// stop immediately instead of trying more names.
   451  			break
   452  		}
   453  	}
   454  	if err == nil {
   455  		return p, server, nil
   456  	}
   457  	if err, ok := err.(*DNSError); ok {
   458  		// Show original name passed to lookup, not suffixed one.
   459  		// In general we might have tried many suffixes; showing
   460  		// just one is misleading. See also golang.org/issue/6324.
   461  		err.Name = name
   462  	}
   463  	return dnsmessage.Parser{}, "", err
   464  }
   465  
   466  // avoidDNS reports whether this is a hostname for which we should not
   467  // use DNS. Currently this includes only .onion, per RFC 7686. See
   468  // golang.org/issue/13705. Does not cover .local names (RFC 6762),
   469  // see golang.org/issue/16739.
   470  func avoidDNS(name string) bool {
   471  	if name == "" {
   472  		return true
   473  	}
   474  	if name[len(name)-1] == '.' {
   475  		name = name[:len(name)-1]
   476  	}
   477  	return stringsHasSuffixFold(name, ".onion")
   478  }
   479  
   480  // nameList returns a list of names for sequential DNS queries.
   481  func (conf *dnsConfig) nameList(name string) []string {
   482  	if avoidDNS(name) {
   483  		return nil
   484  	}
   485  
   486  	// Check name length (see isDomainName).
   487  	l := len(name)
   488  	rooted := l > 0 && name[l-1] == '.'
   489  	if l > 254 || l == 254 && !rooted {
   490  		return nil
   491  	}
   492  
   493  	// If name is rooted (trailing dot), try only that name.
   494  	if rooted {
   495  		return []string{name}
   496  	}
   497  
   498  	hasNdots := count(name, '.') >= conf.ndots
   499  	name += "."
   500  	l++
   501  
   502  	// Build list of search choices.
   503  	names := make([]string, 0, 1+len(conf.search))
   504  	// If name has enough dots, try unsuffixed first.
   505  	if hasNdots {
   506  		names = append(names, name)
   507  	}
   508  	// Try suffixes that are not too long (see isDomainName).
   509  	for _, suffix := range conf.search {
   510  		if l+len(suffix) <= 254 {
   511  			names = append(names, name+suffix)
   512  		}
   513  	}
   514  	// Try unsuffixed, if not tried first above.
   515  	if !hasNdots {
   516  		names = append(names, name)
   517  	}
   518  	return names
   519  }
   520  
   521  // hostLookupOrder specifies the order of LookupHost lookup strategies.
   522  // It is basically a simplified representation of nsswitch.conf.
   523  // "files" means /etc/hosts.
   524  type hostLookupOrder int
   525  
   526  const (
   527  	// hostLookupCgo means defer to cgo.
   528  	hostLookupCgo      hostLookupOrder = iota
   529  	hostLookupFilesDNS                 // files first
   530  	hostLookupDNSFiles                 // dns first
   531  	hostLookupFiles                    // only files
   532  	hostLookupDNS                      // only DNS
   533  )
   534  
   535  var lookupOrderName = map[hostLookupOrder]string{
   536  	hostLookupCgo:      "cgo",
   537  	hostLookupFilesDNS: "files,dns",
   538  	hostLookupDNSFiles: "dns,files",
   539  	hostLookupFiles:    "files",
   540  	hostLookupDNS:      "dns",
   541  }
   542  
   543  func (o hostLookupOrder) String() string {
   544  	if s, ok := lookupOrderName[o]; ok {
   545  		return s
   546  	}
   547  	return "hostLookupOrder=" + itoa.Itoa(int(o)) + "??"
   548  }
   549  
   550  func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder, conf *dnsConfig) (addrs []string, err error) {
   551  	if order == hostLookupFilesDNS || order == hostLookupFiles {
   552  		// Use entries from /etc/hosts if they match.
   553  		addrs, _ = lookupStaticHost(name)
   554  		if len(addrs) > 0 {
   555  			return
   556  		}
   557  
   558  		if order == hostLookupFiles {
   559  			return nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
   560  		}
   561  	}
   562  	ips, _, err := r.goLookupIPCNAMEOrder(ctx, "ip", name, order, conf)
   563  	if err != nil {
   564  		return
   565  	}
   566  	addrs = make([]string, 0, len(ips))
   567  	for _, ip := range ips {
   568  		addrs = append(addrs, ip.String())
   569  	}
   570  	return
   571  }
   572  
   573  // lookup entries from /etc/hosts
   574  func goLookupIPFiles(name string) (addrs []IPAddr, canonical string) {
   575  	addr, canonical := lookupStaticHost(name)
   576  	for _, haddr := range addr {
   577  		haddr, zone := splitHostZone(haddr)
   578  		if ip := ParseIP(haddr); ip != nil {
   579  			addr := IPAddr{IP: ip, Zone: zone}
   580  			addrs = append(addrs, addr)
   581  		}
   582  	}
   583  	sortByRFC6724(addrs)
   584  	return addrs, canonical
   585  }
   586  
   587  // goLookupIP is the native Go implementation of LookupIP.
   588  // The libc versions are in cgo_*.go.
   589  func (r *Resolver) goLookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
   590  	order, conf := systemConf().hostLookupOrder(r, host)
   591  	addrs, _, err = r.goLookupIPCNAMEOrder(ctx, network, host, order, conf)
   592  	return
   593  }
   594  
   595  func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name string, order hostLookupOrder, conf *dnsConfig) (addrs []IPAddr, cname dnsmessage.Name, err error) {
   596  	if order == hostLookupFilesDNS || order == hostLookupFiles {
   597  		var canonical string
   598  		addrs, canonical = goLookupIPFiles(name)
   599  
   600  		if len(addrs) > 0 {
   601  			var err error
   602  			cname, err = dnsmessage.NewName(canonical)
   603  			if err != nil {
   604  				return nil, dnsmessage.Name{}, err
   605  			}
   606  			return addrs, cname, nil
   607  		}
   608  
   609  		if order == hostLookupFiles {
   610  			return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
   611  		}
   612  	}
   613  
   614  	if !isDomainName(name) {
   615  		// See comment in func lookup above about use of errNoSuchHost.
   616  		return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
   617  	}
   618  	type result struct {
   619  		p      dnsmessage.Parser
   620  		server string
   621  		error
   622  	}
   623  
   624  	if conf == nil {
   625  		conf = getSystemDNSConfig()
   626  	}
   627  
   628  	lane := make(chan result, 1)
   629  	qtypes := []dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
   630  	if network == "CNAME" {
   631  		qtypes = append(qtypes, dnsmessage.TypeCNAME)
   632  	}
   633  	switch ipVersion(network) {
   634  	case '4':
   635  		qtypes = []dnsmessage.Type{dnsmessage.TypeA}
   636  	case '6':
   637  		qtypes = []dnsmessage.Type{dnsmessage.TypeAAAA}
   638  	}
   639  	var queryFn func(fqdn string, qtype dnsmessage.Type)
   640  	var responseFn func(fqdn string, qtype dnsmessage.Type) result
   641  	if conf.singleRequest {
   642  		queryFn = func(fqdn string, qtype dnsmessage.Type) {}
   643  		responseFn = func(fqdn string, qtype dnsmessage.Type) result {
   644  			dnsWaitGroup.Add(1)
   645  			defer dnsWaitGroup.Done()
   646  			p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
   647  			return result{p, server, err}
   648  		}
   649  	} else {
   650  		queryFn = func(fqdn string, qtype dnsmessage.Type) {
   651  			dnsWaitGroup.Add(1)
   652  			go func(qtype dnsmessage.Type) {
   653  				p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
   654  				lane <- result{p, server, err}
   655  				dnsWaitGroup.Done()
   656  			}(qtype)
   657  		}
   658  		responseFn = func(fqdn string, qtype dnsmessage.Type) result {
   659  			return <-lane
   660  		}
   661  	}
   662  	var lastErr error
   663  	for _, fqdn := range conf.nameList(name) {
   664  		for _, qtype := range qtypes {
   665  			queryFn(fqdn, qtype)
   666  		}
   667  		hitStrictError := false
   668  		for _, qtype := range qtypes {
   669  			result := responseFn(fqdn, qtype)
   670  			if result.error != nil {
   671  				if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
   672  					// This error will abort the nameList loop.
   673  					hitStrictError = true
   674  					lastErr = result.error
   675  				} else if lastErr == nil || fqdn == name+"." {
   676  					// Prefer error for original name.
   677  					lastErr = result.error
   678  				}
   679  				continue
   680  			}
   681  
   682  			// Presotto says it's okay to assume that servers listed in
   683  			// /etc/resolv.conf are recursive resolvers.
   684  			//
   685  			// We asked for recursion, so it should have included all the
   686  			// answers we need in this one packet.
   687  			//
   688  			// Further, RFC 1034 section 4.3.1 says that "the recursive
   689  			// response to a query will be... The answer to the query,
   690  			// possibly preface by one or more CNAME RRs that specify
   691  			// aliases encountered on the way to an answer."
   692  			//
   693  			// Therefore, we should be able to assume that we can ignore
   694  			// CNAMEs and that the A and AAAA records we requested are
   695  			// for the canonical name.
   696  
   697  		loop:
   698  			for {
   699  				h, err := result.p.AnswerHeader()
   700  				if err != nil && err != dnsmessage.ErrSectionDone {
   701  					lastErr = &DNSError{
   702  						Err:    "cannot marshal DNS message",
   703  						Name:   name,
   704  						Server: result.server,
   705  					}
   706  				}
   707  				if err != nil {
   708  					break
   709  				}
   710  				switch h.Type {
   711  				case dnsmessage.TypeA:
   712  					a, err := result.p.AResource()
   713  					if err != nil {
   714  						lastErr = &DNSError{
   715  							Err:    "cannot marshal DNS message",
   716  							Name:   name,
   717  							Server: result.server,
   718  						}
   719  						break loop
   720  					}
   721  					addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
   722  					if cname.Length == 0 && h.Name.Length != 0 {
   723  						cname = h.Name
   724  					}
   725  
   726  				case dnsmessage.TypeAAAA:
   727  					aaaa, err := result.p.AAAAResource()
   728  					if err != nil {
   729  						lastErr = &DNSError{
   730  							Err:    "cannot marshal DNS message",
   731  							Name:   name,
   732  							Server: result.server,
   733  						}
   734  						break loop
   735  					}
   736  					addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
   737  					if cname.Length == 0 && h.Name.Length != 0 {
   738  						cname = h.Name
   739  					}
   740  
   741  				case dnsmessage.TypeCNAME:
   742  					c, err := result.p.CNAMEResource()
   743  					if err != nil {
   744  						lastErr = &DNSError{
   745  							Err:    "cannot marshal DNS message",
   746  							Name:   name,
   747  							Server: result.server,
   748  						}
   749  						break loop
   750  					}
   751  					if cname.Length == 0 && c.CNAME.Length > 0 {
   752  						cname = c.CNAME
   753  					}
   754  
   755  				default:
   756  					if err := result.p.SkipAnswer(); err != nil {
   757  						lastErr = &DNSError{
   758  							Err:    "cannot marshal DNS message",
   759  							Name:   name,
   760  							Server: result.server,
   761  						}
   762  						break loop
   763  					}
   764  					continue
   765  				}
   766  			}
   767  		}
   768  		if hitStrictError {
   769  			// If either family hit an error with StrictErrors enabled,
   770  			// discard all addresses. This ensures that network flakiness
   771  			// cannot turn a dualstack hostname IPv4/IPv6-only.
   772  			addrs = nil
   773  			break
   774  		}
   775  		if len(addrs) > 0 || network == "CNAME" && cname.Length > 0 {
   776  			break
   777  		}
   778  	}
   779  	if lastErr, ok := lastErr.(*DNSError); ok {
   780  		// Show original name passed to lookup, not suffixed one.
   781  		// In general we might have tried many suffixes; showing
   782  		// just one is misleading. See also golang.org/issue/6324.
   783  		lastErr.Name = name
   784  	}
   785  	sortByRFC6724(addrs)
   786  	if len(addrs) == 0 && !(network == "CNAME" && cname.Length > 0) {
   787  		if order == hostLookupDNSFiles {
   788  			var canonical string
   789  			addrs, canonical = goLookupIPFiles(name)
   790  			if len(addrs) > 0 {
   791  				var err error
   792  				cname, err = dnsmessage.NewName(canonical)
   793  				if err != nil {
   794  					return nil, dnsmessage.Name{}, err
   795  				}
   796  				return addrs, cname, nil
   797  			}
   798  		}
   799  		if lastErr != nil {
   800  			return nil, dnsmessage.Name{}, lastErr
   801  		}
   802  	}
   803  	return addrs, cname, nil
   804  }
   805  
   806  // goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
   807  func (r *Resolver) goLookupCNAME(ctx context.Context, host string, order hostLookupOrder, conf *dnsConfig) (string, error) {
   808  	_, cname, err := r.goLookupIPCNAMEOrder(ctx, "CNAME", host, order, conf)
   809  	return cname.String(), err
   810  }
   811  
   812  // goLookupPTR is the native Go implementation of LookupAddr.
   813  // Used only if cgoLookupPTR refuses to handle the request (that is,
   814  // only if cgoLookupPTR is the stub in cgo_stub.go).
   815  // Normally we let cgo use the C library resolver instead of depending
   816  // on our lookup code, so that Go and C get the same answers.
   817  func (r *Resolver) goLookupPTR(ctx context.Context, addr string, conf *dnsConfig) ([]string, error) {
   818  	names := lookupStaticAddr(addr)
   819  	if len(names) > 0 {
   820  		return names, nil
   821  	}
   822  	arpa, err := reverseaddr(addr)
   823  	if err != nil {
   824  		return nil, err
   825  	}
   826  	p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR, conf)
   827  	if err != nil {
   828  		return nil, err
   829  	}
   830  	var ptrs []string
   831  	for {
   832  		h, err := p.AnswerHeader()
   833  		if err == dnsmessage.ErrSectionDone {
   834  			break
   835  		}
   836  		if err != nil {
   837  			return nil, &DNSError{
   838  				Err:    "cannot marshal DNS message",
   839  				Name:   addr,
   840  				Server: server,
   841  			}
   842  		}
   843  		if h.Type != dnsmessage.TypePTR {
   844  			err := p.SkipAnswer()
   845  			if err != nil {
   846  				return nil, &DNSError{
   847  					Err:    "cannot marshal DNS message",
   848  					Name:   addr,
   849  					Server: server,
   850  				}
   851  			}
   852  			continue
   853  		}
   854  		ptr, err := p.PTRResource()
   855  		if err != nil {
   856  			return nil, &DNSError{
   857  				Err:    "cannot marshal DNS message",
   858  				Name:   addr,
   859  				Server: server,
   860  			}
   861  		}
   862  		ptrs = append(ptrs, ptr.PTR.String())
   863  
   864  	}
   865  	return ptrs, nil
   866  }
   867  

View as plain text