Source file
src/net/resolverdialfunc_test.go
1
2
3
4
5
6
7
8
9
10 package net
11
12 import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "reflect"
18 "sort"
19 "testing"
20 "time"
21
22 "golang.org/x/net/dns/dnsmessage"
23 )
24
25 func TestResolverDialFunc(t *testing.T) {
26 r := &Resolver{
27 PreferGo: true,
28 Dial: newResolverDialFunc(&resolverDialHandler{
29 StartDial: func(network, address string) error {
30 t.Logf("StartDial(%q, %q) ...", network, address)
31 return nil
32 },
33 Question: func(h dnsmessage.Header, q dnsmessage.Question) {
34 t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
35 q.Name.String(), q.Type, q.Class)
36 },
37
38
39 HandleA: func(w AWriter, name string) error {
40 w.AddIP([4]byte{1, 2, 3, 4})
41 w.AddIP([4]byte{5, 6, 7, 8})
42 return nil
43 },
44 HandleAAAA: func(w AAAAWriter, name string) error {
45 w.AddIP([16]byte{1: 1, 15: 15})
46 w.AddIP([16]byte{2: 2, 14: 14})
47 return nil
48 },
49 HandleSRV: func(w SRVWriter, name string) error {
50 w.AddSRV(1, 2, 80, "foo.bar.")
51 w.AddSRV(2, 3, 81, "bar.baz.")
52 return nil
53 },
54 }),
55 }
56 ctx := context.Background()
57 const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
58
59 t.Run("LookupIP", func(t *testing.T) {
60 ips, err := r.LookupIP(ctx, "ip", fakeDomain)
61 if err != nil {
62 t.Fatal(err)
63 }
64 if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
65 t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
66 }
67 })
68
69 t.Run("LookupSRV", func(t *testing.T) {
70 _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
71 if err != nil {
72 t.Fatal(err)
73 }
74 want := []*SRV{
75 {
76 Target: "foo.bar.",
77 Port: 80,
78 Priority: 1,
79 Weight: 2,
80 },
81 {
82 Target: "bar.baz.",
83 Port: 81,
84 Priority: 2,
85 Weight: 3,
86 },
87 }
88 if !reflect.DeepEqual(got, want) {
89 t.Errorf("wrong result. got:")
90 for _, r := range got {
91 t.Logf(" - %+v", r)
92 }
93 }
94 })
95 }
96
97 func sortedIPStrings(ips []IP) []string {
98 ret := make([]string, len(ips))
99 for i, ip := range ips {
100 ret[i] = ip.String()
101 }
102 sort.Strings(ret)
103 return ret
104 }
105
106 func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
107 return func(ctx context.Context, network, address string) (Conn, error) {
108 a := &resolverFuncConn{
109 h: h,
110 network: network,
111 address: address,
112 ttl: 10,
113 }
114 if h.StartDial != nil {
115 if err := h.StartDial(network, address); err != nil {
116 return nil, err
117 }
118 }
119 return a, nil
120 }
121 }
122
123 type resolverDialHandler struct {
124
125
126 StartDial func(network, address string) error
127
128 Question func(dnsmessage.Header, dnsmessage.Question)
129
130
131
132 HandleA func(w AWriter, name string) error
133 HandleAAAA func(w AAAAWriter, name string) error
134 HandleSRV func(w SRVWriter, name string) error
135 }
136
137 type ResponseWriter struct{ a *resolverFuncConn }
138
139 func (w ResponseWriter) header() dnsmessage.ResourceHeader {
140 q := w.a.q
141 return dnsmessage.ResourceHeader{
142 Name: q.Name,
143 Type: q.Type,
144 Class: q.Class,
145 TTL: w.a.ttl,
146 }
147 }
148
149
150
151
152
153 func (w ResponseWriter) SetTTL(seconds uint32) {
154
155
156
157
158
159
160 if w.a.wrote {
161 return
162 }
163 w.a.ttl = seconds
164
165 }
166
167 type AWriter struct{ ResponseWriter }
168
169 func (w AWriter) AddIP(v4 [4]byte) {
170 w.a.wrote = true
171 err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
172 if err != nil {
173 panic(err)
174 }
175 }
176
177 type AAAAWriter struct{ ResponseWriter }
178
179 func (w AAAAWriter) AddIP(v6 [16]byte) {
180 w.a.wrote = true
181 err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
182 if err != nil {
183 panic(err)
184 }
185 }
186
187 type SRVWriter struct{ ResponseWriter }
188
189
190
191 func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
192 targetName, err := dnsmessage.NewName(target)
193 if err != nil {
194 return err
195 }
196 w.a.wrote = true
197 err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
198 Priority: priority,
199 Weight: weight,
200 Port: port,
201 Target: targetName,
202 })
203 if err != nil {
204 panic(err)
205 }
206 return nil
207 }
208
// ErrNotExist may be returned by a resolverDialHandler hook to report that
// the queried name does not exist; mapRCode maps it to RCodeNameError.
var ErrNotExist = errors.New("name does not exist")

// ErrRefused may be returned by a resolverDialHandler hook to refuse the
// query; mapRCode maps it to RCodeRefused.
var ErrRefused = errors.New("refused")
213
214 type resolverFuncConn struct {
215 h *resolverDialHandler
216 network string
217 address string
218 builder *dnsmessage.Builder
219 q dnsmessage.Question
220 ttl uint32
221 wrote bool
222
223 rbuf bytes.Buffer
224 }
225
226 func (*resolverFuncConn) Close() error { return nil }
227 func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} }
228 func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} }
229 func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil }
230 func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil }
231 func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
232
233 func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
234 return a.rbuf.Read(p)
235 }
236
237 func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
238 if len(packet) < 2 {
239 return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
240 }
241 reqLen := int(packet[0])<<8 | int(packet[1])
242 req := packet[2:]
243 if len(req) != reqLen {
244 return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
245 }
246
247 var parser dnsmessage.Parser
248 h, err := parser.Start(req)
249 if err != nil {
250
251 return 0, err
252 }
253 q, err := parser.Question()
254 hadQ := (err == nil)
255 if err == nil && a.h.Question != nil {
256 a.h.Question(h, q)
257 }
258 if err != nil && err != dnsmessage.ErrSectionDone {
259 return 0, err
260 }
261
262 resh := h
263 resh.Response = true
264 resh.Authoritative = true
265 if hadQ {
266 resh.RCode = dnsmessage.RCodeSuccess
267 } else {
268 resh.RCode = dnsmessage.RCodeNotImplemented
269 }
270 a.rbuf.Grow(514)
271 a.rbuf.WriteByte('X')
272 a.rbuf.WriteByte('Y')
273 builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
274 a.builder = &builder
275 if hadQ {
276 a.q = q
277 a.builder.StartQuestions()
278 err := a.builder.Question(q)
279 if err != nil {
280 return 0, fmt.Errorf("Question: %w", err)
281 }
282 a.builder.StartAnswers()
283 switch q.Type {
284 case dnsmessage.TypeA:
285 if a.h.HandleA != nil {
286 resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
287 }
288 case dnsmessage.TypeAAAA:
289 if a.h.HandleAAAA != nil {
290 resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
291 }
292 case dnsmessage.TypeSRV:
293 if a.h.HandleSRV != nil {
294 resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
295 }
296 }
297 }
298 tcpRes, err := builder.Finish()
299 if err != nil {
300 return 0, fmt.Errorf("Finish: %w", err)
301 }
302
303 n = len(tcpRes) - 2
304 tcpRes[0] = byte(n >> 8)
305 tcpRes[1] = byte(n)
306 a.rbuf.Write(tcpRes[2:])
307
308 return len(packet), nil
309 }
310
// someaddr is the placeholder Addr reported for both ends of a
// resolverFuncConn.
type someaddr struct{}

// Network returns a fixed placeholder network name.
func (someaddr) Network() string {
	return "unused"
}

// String returns a fixed placeholder address string.
func (someaddr) String() string {
	return "unused-someaddr"
}
315
316 func mapRCode(err error) dnsmessage.RCode {
317 switch err {
318 case nil:
319 return dnsmessage.RCodeSuccess
320 case ErrNotExist:
321 return dnsmessage.RCodeNameError
322 case ErrRefused:
323 return dnsmessage.RCodeRefused
324 default:
325 return dnsmessage.RCodeServerFailure
326 }
327 }
328
View as plain text