Source file src/net/resolverdialfunc_test.go

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build !js
     6  
     7  // Test that Resolver.Dial can be a func returning an in-memory net.Conn
     8  // speaking DNS.
     9  
    10  package net
    11  
    12  import (
    13  	"bytes"
    14  	"context"
    15  	"errors"
    16  	"fmt"
    17  	"reflect"
    18  	"sort"
    19  	"testing"
    20  	"time"
    21  
    22  	"golang.org/x/net/dns/dnsmessage"
    23  )
    24  
    25  func TestResolverDialFunc(t *testing.T) {
    26  	r := &Resolver{
    27  		PreferGo: true,
    28  		Dial: newResolverDialFunc(&resolverDialHandler{
    29  			StartDial: func(network, address string) error {
    30  				t.Logf("StartDial(%q, %q) ...", network, address)
    31  				return nil
    32  			},
    33  			Question: func(h dnsmessage.Header, q dnsmessage.Question) {
    34  				t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
    35  					q.Name.String(), q.Type, q.Class)
    36  			},
    37  			// TODO: add test without HandleA* hooks specified at all, that Go
    38  			// doesn't issue retries; map to something terminal.
    39  			HandleA: func(w AWriter, name string) error {
    40  				w.AddIP([4]byte{1, 2, 3, 4})
    41  				w.AddIP([4]byte{5, 6, 7, 8})
    42  				return nil
    43  			},
    44  			HandleAAAA: func(w AAAAWriter, name string) error {
    45  				w.AddIP([16]byte{1: 1, 15: 15})
    46  				w.AddIP([16]byte{2: 2, 14: 14})
    47  				return nil
    48  			},
    49  			HandleSRV: func(w SRVWriter, name string) error {
    50  				w.AddSRV(1, 2, 80, "foo.bar.")
    51  				w.AddSRV(2, 3, 81, "bar.baz.")
    52  				return nil
    53  			},
    54  		}),
    55  	}
    56  	ctx := context.Background()
    57  	const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
    58  
    59  	t.Run("LookupIP", func(t *testing.T) {
    60  		ips, err := r.LookupIP(ctx, "ip", fakeDomain)
    61  		if err != nil {
    62  			t.Fatal(err)
    63  		}
    64  		if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
    65  			t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
    66  		}
    67  	})
    68  
    69  	t.Run("LookupSRV", func(t *testing.T) {
    70  		_, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
    71  		if err != nil {
    72  			t.Fatal(err)
    73  		}
    74  		want := []*SRV{
    75  			{
    76  				Target:   "foo.bar.",
    77  				Port:     80,
    78  				Priority: 1,
    79  				Weight:   2,
    80  			},
    81  			{
    82  				Target:   "bar.baz.",
    83  				Port:     81,
    84  				Priority: 2,
    85  				Weight:   3,
    86  			},
    87  		}
    88  		if !reflect.DeepEqual(got, want) {
    89  			t.Errorf("wrong result. got:")
    90  			for _, r := range got {
    91  				t.Logf("  - %+v", r)
    92  			}
    93  		}
    94  	})
    95  }
    96  
    97  func sortedIPStrings(ips []IP) []string {
    98  	ret := make([]string, len(ips))
    99  	for i, ip := range ips {
   100  		ret[i] = ip.String()
   101  	}
   102  	sort.Strings(ret)
   103  	return ret
   104  }
   105  
   106  func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
   107  	return func(ctx context.Context, network, address string) (Conn, error) {
   108  		a := &resolverFuncConn{
   109  			h:       h,
   110  			network: network,
   111  			address: address,
   112  			ttl:     10, // 10 second default if unset
   113  		}
   114  		if h.StartDial != nil {
   115  			if err := h.StartDial(network, address); err != nil {
   116  				return nil, err
   117  			}
   118  		}
   119  		return a, nil
   120  	}
   121  }
   122  
   123  type resolverDialHandler struct {
   124  	// StartDial, if non-nil, is called when Go first calls Resolver.Dial.
   125  	// Any error returned aborts the dial and is returned unwrapped.
   126  	StartDial func(network, address string) error
   127  
   128  	Question func(dnsmessage.Header, dnsmessage.Question)
   129  
   130  	// err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
   131  	// A nil error means success.
   132  	HandleA    func(w AWriter, name string) error
   133  	HandleAAAA func(w AAAAWriter, name string) error
   134  	HandleSRV  func(w SRVWriter, name string) error
   135  }
   136  
   137  type ResponseWriter struct{ a *resolverFuncConn }
   138  
   139  func (w ResponseWriter) header() dnsmessage.ResourceHeader {
   140  	q := w.a.q
   141  	return dnsmessage.ResourceHeader{
   142  		Name:  q.Name,
   143  		Type:  q.Type,
   144  		Class: q.Class,
   145  		TTL:   w.a.ttl,
   146  	}
   147  }
   148  
   149  // SetTTL sets the TTL for subsequent written resources.
   150  // Once a resource has been written, SetTTL calls are no-ops.
   151  // That is, it can only be called at most once, before anything
   152  // else is written.
   153  func (w ResponseWriter) SetTTL(seconds uint32) {
   154  	// ... intention is last one wins and mutates all previously
   155  	// written records too, but that's a little annoying.
   156  	// But it's also annoying if the requirement is it needs to be set
   157  	// last.
   158  	// And it's also annoying if it's possible for users to set
   159  	// different TTLs per Answer.
   160  	if w.a.wrote {
   161  		return
   162  	}
   163  	w.a.ttl = seconds
   164  
   165  }
   166  
   167  type AWriter struct{ ResponseWriter }
   168  
   169  func (w AWriter) AddIP(v4 [4]byte) {
   170  	w.a.wrote = true
   171  	err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
   172  	if err != nil {
   173  		panic(err)
   174  	}
   175  }
   176  
   177  type AAAAWriter struct{ ResponseWriter }
   178  
   179  func (w AAAAWriter) AddIP(v6 [16]byte) {
   180  	w.a.wrote = true
   181  	err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
   182  	if err != nil {
   183  		panic(err)
   184  	}
   185  }
   186  
   187  type SRVWriter struct{ ResponseWriter }
   188  
   189  // AddSRV adds a SRV record. The target name must end in a period and
   190  // be 63 bytes or fewer.
   191  func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
   192  	targetName, err := dnsmessage.NewName(target)
   193  	if err != nil {
   194  		return err
   195  	}
   196  	w.a.wrote = true
   197  	err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
   198  		Priority: priority,
   199  		Weight:   weight,
   200  		Port:     port,
   201  		Target:   targetName,
   202  	})
   203  	if err != nil {
   204  		panic(err) // internal fault, not user
   205  	}
   206  	return nil
   207  }
   208  
   209  var (
   210  	ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
   211  	ErrRefused  = errors.New("refused")             // maps to RCode5, REFUSED
   212  )
   213  
   214  type resolverFuncConn struct {
   215  	h       *resolverDialHandler
   216  	ctx     context.Context
   217  	network string
   218  	address string
   219  	builder *dnsmessage.Builder
   220  	q       dnsmessage.Question
   221  	ttl     uint32
   222  	wrote   bool
   223  
   224  	rbuf bytes.Buffer
   225  }
   226  
   227  func (*resolverFuncConn) Close() error                       { return nil }
   228  func (*resolverFuncConn) LocalAddr() Addr                    { return someaddr{} }
   229  func (*resolverFuncConn) RemoteAddr() Addr                   { return someaddr{} }
   230  func (*resolverFuncConn) SetDeadline(t time.Time) error      { return nil }
   231  func (*resolverFuncConn) SetReadDeadline(t time.Time) error  { return nil }
   232  func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
   233  
   234  func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
   235  	return a.rbuf.Read(p)
   236  }
   237  
   238  func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
   239  	if len(packet) < 2 {
   240  		return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
   241  	}
   242  	reqLen := int(packet[0])<<8 | int(packet[1])
   243  	req := packet[2:]
   244  	if len(req) != reqLen {
   245  		return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
   246  	}
   247  
   248  	var parser dnsmessage.Parser
   249  	h, err := parser.Start(req)
   250  	if err != nil {
   251  		// TODO: hook
   252  		return 0, err
   253  	}
   254  	q, err := parser.Question()
   255  	hadQ := (err == nil)
   256  	if err == nil && a.h.Question != nil {
   257  		a.h.Question(h, q)
   258  	}
   259  	if err != nil && err != dnsmessage.ErrSectionDone {
   260  		return 0, err
   261  	}
   262  
   263  	resh := h
   264  	resh.Response = true
   265  	resh.Authoritative = true
   266  	if hadQ {
   267  		resh.RCode = dnsmessage.RCodeSuccess
   268  	} else {
   269  		resh.RCode = dnsmessage.RCodeNotImplemented
   270  	}
   271  	a.rbuf.Grow(514)
   272  	a.rbuf.WriteByte('X') // reserved header for beu16 length
   273  	a.rbuf.WriteByte('Y') // reserved header for beu16 length
   274  	builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
   275  	a.builder = &builder
   276  	if hadQ {
   277  		a.q = q
   278  		a.builder.StartQuestions()
   279  		err := a.builder.Question(q)
   280  		if err != nil {
   281  			return 0, fmt.Errorf("Question: %w", err)
   282  		}
   283  		a.builder.StartAnswers()
   284  		switch q.Type {
   285  		case dnsmessage.TypeA:
   286  			if a.h.HandleA != nil {
   287  				resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
   288  			}
   289  		case dnsmessage.TypeAAAA:
   290  			if a.h.HandleAAAA != nil {
   291  				resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
   292  			}
   293  		case dnsmessage.TypeSRV:
   294  			if a.h.HandleSRV != nil {
   295  				resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
   296  			}
   297  		}
   298  	}
   299  	tcpRes, err := builder.Finish()
   300  	if err != nil {
   301  		return 0, fmt.Errorf("Finish: %w", err)
   302  	}
   303  
   304  	n = len(tcpRes) - 2
   305  	tcpRes[0] = byte(n >> 8)
   306  	tcpRes[1] = byte(n)
   307  	a.rbuf.Write(tcpRes[2:])
   308  
   309  	return len(packet), nil
   310  }
   311  
   312  type someaddr struct{}
   313  
   314  func (someaddr) Network() string { return "unused" }
   315  func (someaddr) String() string  { return "unused-someaddr" }
   316  
   317  func mapRCode(err error) dnsmessage.RCode {
   318  	switch err {
   319  	case nil:
   320  		return dnsmessage.RCodeSuccess
   321  	case ErrNotExist:
   322  		return dnsmessage.RCodeNameError
   323  	case ErrRefused:
   324  		return dnsmessage.RCodeRefused
   325  	default:
   326  		return dnsmessage.RCodeServerFailure
   327  	}
   328  }
   329  

View as plain text