Source file
src/net/dnsclient_unix_test.go
1
2
3
4
5
6
7 package net
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "os"
14 "path"
15 "reflect"
16 "strings"
17 "sync"
18 "sync/atomic"
19 "testing"
20 "time"
21
22 "golang.org/x/net/dns/dnsmessage"
23 )
24
25 var goResolver = Resolver{PreferGo: true}
26
27
28 var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
29
30
31 var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
32
33 func mustNewName(name string) dnsmessage.Name {
34 nn, err := dnsmessage.NewName(name)
35 if err != nil {
36 panic(fmt.Sprint("creating name: ", err))
37 }
38 return nn
39 }
40
41 func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question {
42 return dnsmessage.Question{
43 Name: mustNewName(name),
44 Type: qtype,
45 Class: class,
46 }
47 }
48
// dnsTransportFallbackTests lists the queries issued by
// TestDNSTransportFallback. Each entry names the server address, the
// question to send and the rcode expected in the final response.
// timeout is a number of seconds; NOTE(review): TestDNSTransportFallback
// currently passes time.Second instead of reading this field — confirm
// whether it is still needed.
var dnsTransportFallbackTests = []struct {
	server   string
	question dnsmessage.Question
	timeout  int
	rcode    dnsmessage.RCode
}{
	// TypeALL ("ANY") queries for "com." are used because their answers
	// tend to be large. The fake server in TestDNSTransportFallback marks
	// every UDP reply as truncated, so these cases exercise the UDP-to-TCP
	// fallback path.
	{"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess},
	{"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess},
}
60
61 func TestDNSTransportFallback(t *testing.T) {
62 fake := fakeDNSServer{
63 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
64 r := dnsmessage.Message{
65 Header: dnsmessage.Header{
66 ID: q.Header.ID,
67 Response: true,
68 RCode: dnsmessage.RCodeSuccess,
69 },
70 Questions: q.Questions,
71 }
72 if n == "udp" {
73 r.Header.Truncated = true
74 }
75 return r, nil
76 },
77 }
78 r := Resolver{PreferGo: true, Dial: fake.DialContext}
79 for _, tt := range dnsTransportFallbackTests {
80 ctx, cancel := context.WithCancel(context.Background())
81 defer cancel()
82 _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second, useUDPOrTCP)
83 if err != nil {
84 t.Error(err)
85 continue
86 }
87 if h.RCode != tt.rcode {
88 t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode)
89 continue
90 }
91 }
92 }
93
94
95
// specialDomainNameTests lists questions for special-use domain names
// (RFC 6761) together with the rcode expected from the DNS server.
var specialDomainNameTests = []struct {
	question dnsmessage.Question
	rcode    dnsmessage.RCode
}{
	// RFC 6761 says name resolution APIs and libraries should NOT treat
	// these names as special, so they are resolved through DNS as usual.
	{mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
	{mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
	{mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess},

	// RFC 6761 says APIs and libraries SHOULD treat these names as special
	// and avoid querying for them. They are still sent to the server here
	// to check the negative answer at the DNS level.
	{mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
	{mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
}
113
114 func TestSpecialDomainName(t *testing.T) {
115 fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
116 r := dnsmessage.Message{
117 Header: dnsmessage.Header{
118 ID: q.ID,
119 Response: true,
120 },
121 Questions: q.Questions,
122 }
123
124 switch q.Questions[0].Name.String() {
125 case "example.com.":
126 r.Header.RCode = dnsmessage.RCodeSuccess
127 default:
128 r.Header.RCode = dnsmessage.RCodeNameError
129 }
130
131 return r, nil
132 }}
133 r := Resolver{PreferGo: true, Dial: fake.DialContext}
134 server := "8.8.8.8:53"
135 for _, tt := range specialDomainNameTests {
136 ctx, cancel := context.WithCancel(context.Background())
137 defer cancel()
138 _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second, useUDPOrTCP)
139 if err != nil {
140 t.Error(err)
141 continue
142 }
143 if h.RCode != tt.rcode {
144 t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode)
145 continue
146 }
147 }
148 }
149
150
151 func TestAvoidDNSName(t *testing.T) {
152 tests := []struct {
153 name string
154 avoid bool
155 }{
156 {"foo.com", false},
157 {"foo.com.", false},
158
159 {"foo.onion.", true},
160 {"foo.onion", true},
161 {"foo.ONION", true},
162 {"foo.ONION.", true},
163
164
165 {"foo.local.", false},
166 {"foo.local", false},
167 {"foo.LOCAL", false},
168 {"foo.LOCAL.", false},
169
170 {"", true},
171
172
173
174
175
176
177
178 {"local", false},
179 {"onion", false},
180 {"local.", false},
181 {"onion.", false},
182 }
183 for _, tt := range tests {
184 got := avoidDNS(tt.name)
185 if got != tt.avoid {
186 t.Errorf("avoidDNS(%q) = %v; want %v", tt.name, got, tt.avoid)
187 }
188 }
189 }
190
191 var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
192 r := dnsmessage.Message{
193 Header: dnsmessage.Header{
194 ID: q.ID,
195 Response: true,
196 },
197 Questions: q.Questions,
198 }
199 if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA {
200 r.Answers = []dnsmessage.Resource{
201 {
202 Header: dnsmessage.ResourceHeader{
203 Name: q.Questions[0].Name,
204 Type: dnsmessage.TypeA,
205 Class: dnsmessage.ClassINET,
206 Length: 4,
207 },
208 Body: &dnsmessage.AResource{
209 A: TestAddr,
210 },
211 },
212 }
213 }
214 return r, nil
215 }}
216
217
218 func TestLookupTorOnion(t *testing.T) {
219 defer dnsWaitGroup.Wait()
220 r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
221 addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
222 if err != nil {
223 t.Fatalf("lookup = %v; want nil", err)
224 }
225 if len(addrs) > 0 {
226 t.Errorf("unexpected addresses: %v", addrs)
227 }
228 }
229
// resolvConfTest lets a test install its own resolv.conf contents.
// newResolvConfTest points the embedded resolverConfig at the
// package-level resolvConf, so changes made through it affect every
// resolver in the package until teardown restores /etc/resolv.conf.
type resolvConfTest struct {
	// dir is the temporary directory that holds the test resolv.conf.
	dir string
	// path is the full path of the test resolv.conf inside dir.
	path string
	*resolverConfig
}
235
236 func newResolvConfTest() (*resolvConfTest, error) {
237 dir, err := os.MkdirTemp("", "go-resolvconftest")
238 if err != nil {
239 return nil, err
240 }
241 conf := &resolvConfTest{
242 dir: dir,
243 path: path.Join(dir, "resolv.conf"),
244 resolverConfig: &resolvConf,
245 }
246 conf.initOnce.Do(conf.init)
247 return conf, nil
248 }
249
250 func (conf *resolvConfTest) writeAndUpdate(lines []string) error {
251 f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
252 if err != nil {
253 return err
254 }
255 if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil {
256 f.Close()
257 return err
258 }
259 f.Close()
260 if err := conf.forceUpdate(conf.path, time.Now().Add(time.Hour)); err != nil {
261 return err
262 }
263 return nil
264 }
265
266 func (conf *resolvConfTest) forceUpdate(name string, lastChecked time.Time) error {
267 dnsConf := dnsReadConfig(name)
268 conf.mu.Lock()
269 conf.dnsConfig = dnsConf
270 conf.mu.Unlock()
271 for i := 0; i < 5; i++ {
272 if conf.tryAcquireSema() {
273 conf.lastChecked = lastChecked
274 conf.releaseSema()
275 return nil
276 }
277 }
278 return fmt.Errorf("tryAcquireSema for %s failed", name)
279 }
280
281 func (conf *resolvConfTest) servers() []string {
282 conf.mu.RLock()
283 servers := conf.dnsConfig.servers
284 conf.mu.RUnlock()
285 return servers
286 }
287
288 func (conf *resolvConfTest) teardown() error {
289 err := conf.forceUpdate("/etc/resolv.conf", time.Time{})
290 os.RemoveAll(conf.dir)
291 return err
292 }
293
// updateResolvConfTests is a sequence of resolv.conf contents applied in
// order by TestUpdateResolvConf. After each file is written, name (if
// non-empty) is looked up. servers is the nameserver list the reloaded
// configuration is then expected to hold.
var updateResolvConfTests = []struct {
	name    string
	lines   []string
	servers []string
}{
	{
		name:    "golang.org",
		lines:   []string{"nameserver 8.8.8.8"},
		servers: []string{"8.8.8.8:53"},
	},
	{
		// An empty resolv.conf is expected to fall back to defaultNS.
		name:    "",
		lines:   nil,
		servers: defaultNS,
	},
	{
		name:    "www.example.com",
		lines:   []string{"nameserver 8.8.4.4"},
		servers: []string{"8.8.4.4:53"},
	},
}
315
316 func TestUpdateResolvConf(t *testing.T) {
317 defer dnsWaitGroup.Wait()
318
319 r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
320
321 conf, err := newResolvConfTest()
322 if err != nil {
323 t.Fatal(err)
324 }
325 defer conf.teardown()
326
327 for i, tt := range updateResolvConfTests {
328 if err := conf.writeAndUpdate(tt.lines); err != nil {
329 t.Error(err)
330 continue
331 }
332 if tt.name != "" {
333 var wg sync.WaitGroup
334 const N = 10
335 wg.Add(N)
336 for j := 0; j < N; j++ {
337 go func(name string) {
338 defer wg.Done()
339 ips, err := r.LookupIPAddr(context.Background(), name)
340 if err != nil {
341 t.Error(err)
342 return
343 }
344 if len(ips) == 0 {
345 t.Errorf("no records for %s", name)
346 return
347 }
348 }(tt.name)
349 }
350 wg.Wait()
351 }
352 servers := conf.servers()
353 if !reflect.DeepEqual(servers, tt.servers) {
354 t.Errorf("#%d: got %v; want %v", i, servers, tt.servers)
355 continue
356 }
357 }
358 }
359
// goLookupIPWithResolverConfigTests drives TestGoLookupIPWithResolverConfig.
// Each entry has four parts:
//   - the name to look up;
//   - the resolv.conf lines installed first;
//   - the expected error (a *DNSError, or nil when none is expected),
//     stored in the embedded error field;
//   - whether IPv4 (a) and/or IPv6 (aaaa) addresses may appear in the
//     result.
var goLookupIPWithResolverConfigTests = []struct {
	name  string
	lines []string
	error
	a, aaaa bool
}{
	// Unreachable nameserver: a timeout error from 255.255.255.255:53 is
	// expected.
	{
		"jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
		[]string{
			"options timeout:1 attempts:1",
			"nameserver 255.255.255.255",
		},
		&DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "255.255.255.255:53", IsTimeout: true},
		false, false,
	},

	// Nonexistent name on a reachable server: a non-timeout error from
	// 8.8.8.8:53 is expected.
	{
		"jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
		[]string{
			"options timeout:3 attempts:1",
			"nameserver 8.8.8.8",
		},
		&DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "8.8.8.8:53", IsTimeout: false},
		false, false,
	},

	// Names that should yield IPv4 addresses only, with varying
	// nameserver order and domain/search settings.
	{
		"ipv4.google.com.",
		[]string{
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		true, false,
	},
	{
		"ipv4.google.com",
		[]string{
			"domain golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, false,
	},
	{
		"ipv4.google.com",
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, false,
	},

	// Names that should yield IPv6 addresses only.
	{
		"ipv6.google.com.",
		[]string{
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		false, true,
	},
	{
		"ipv6.google.com",
		[]string{
			"domain golang.org",
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		false, true,
	},
	{
		"ipv6.google.com",
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		false, true,
	},

	// Names for which both address families are permitted.
	{
		"hostname.as112.net",
		[]string{
			"domain golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, true,
	},
	{
		"hostname.as112.net",
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, true,
	},
}
472
// TestGoLookupIPWithResolverConfig installs each entry's resolv.conf lines
// from goLookupIPWithResolverConfigTests and runs LookupIPAddr on its name.
// The fake server answers only for [2001:4860:4860::8888]:53 and
// 8.8.8.8:53. Queries to any other address sleep briefly and then fail
// with os.ErrDeadlineExceeded.
func TestGoLookupIPWithResolverConfig(t *testing.T) {
	defer dnsWaitGroup.Wait()
	fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
		switch s {
		case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
			break
		default:
			// Simulate an unreachable server: wait, then report a timeout.
			time.Sleep(10 * time.Millisecond)
			return dnsmessage.Message{}, os.ErrDeadlineExceeded
		}
		r := dnsmessage.Message{
			Header: dnsmessage.Header{
				ID:       q.ID,
				Response: true,
			},
			Questions: q.Questions,
		}
		for _, question := range q.Questions {
			switch question.Type {
			case dnsmessage.TypeA:
				switch question.Name.String() {
				case "hostname.as112.net.":
					// No A records for this name.
					break
				case "ipv4.google.com.":
					// NOTE(review): the owner name comes from q.Questions[0]
					// rather than question; they coincide only for
					// single-question queries.
					r.Answers = append(r.Answers, dnsmessage.Resource{
						Header: dnsmessage.ResourceHeader{
							Name:   q.Questions[0].Name,
							Type:   dnsmessage.TypeA,
							Class:  dnsmessage.ClassINET,
							Length: 4,
						},
						Body: &dnsmessage.AResource{
							A: TestAddr,
						},
					})
				default:
					// Any other name gets no A records.
				}
			case dnsmessage.TypeAAAA:
				switch question.Name.String() {
				case "hostname.as112.net.":
					// No AAAA records for this name.
					break
				case "ipv6.google.com.":
					r.Answers = append(r.Answers, dnsmessage.Resource{
						Header: dnsmessage.ResourceHeader{
							Name:   q.Questions[0].Name,
							Type:   dnsmessage.TypeAAAA,
							Class:  dnsmessage.ClassINET,
							Length: 16,
						},
						Body: &dnsmessage.AAAAResource{
							AAAA: TestAddr6,
						},
					})
				}
			}
		}
		return r, nil
	}}
	r := Resolver{PreferGo: true, Dial: fake.DialContext}

	conf, err := newResolvConfTest()
	if err != nil {
		t.Fatal(err)
	}
	defer conf.teardown()

	for _, tt := range goLookupIPWithResolverConfigTests {
		if err := conf.writeAndUpdate(tt.lines); err != nil {
			t.Error(err)
			continue
		}
		addrs, err := r.LookupIPAddr(context.Background(), tt.name)
		if err != nil {
			// A non-*DNSError always fails the case. A *DNSError is
			// compared (Name, Server, IsTimeout) against tt.error only when
			// tt.error is non-nil.
			// NOTE(review): when tt.error is nil, any *DNSError is accepted
			// silently. The hostname.as112.net cases, for which the fake
			// server returns no records, presumably pass through this path.
			// Confirm before tightening.
			if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
				t.Errorf("got %v; want %v", err, tt.error)
			}
			continue
		}
		if len(addrs) == 0 {
			t.Errorf("no records for %s", tt.name)
		}
		if !tt.a && !tt.aaaa && len(addrs) > 0 {
			t.Errorf("unexpected %v for %s", addrs, tt.name)
		}
		// Every returned address must belong to a permitted family.
		for _, addr := range addrs {
			if !tt.a && addr.IP.To4() != nil {
				t.Errorf("got %v; must not be IPv4 address", addr)
			}
			if !tt.aaaa && addr.IP.To16() != nil && addr.IP.To4() == nil {
				t.Errorf("got %v; must not be IPv6 address", addr)
			}
		}
	}
}
568
569
570 func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
571 defer dnsWaitGroup.Wait()
572
573 fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
574 r := dnsmessage.Message{
575 Header: dnsmessage.Header{
576 ID: q.ID,
577 Response: true,
578 },
579 Questions: q.Questions,
580 }
581 return r, nil
582 }}
583 r := Resolver{PreferGo: true, Dial: fake.DialContext}
584
585
586 conf, err := newResolvConfTest()
587 if err != nil {
588 t.Fatal(err)
589 }
590 defer conf.teardown()
591
592 if err := conf.writeAndUpdate([]string{}); err != nil {
593 t.Fatal(err)
594 }
595
596 defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
597 testHookHostsPath = "testdata/hosts"
598
599 for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} {
600 name := fmt.Sprintf("order %v", order)
601
602
603 _, _, err := r.goLookupIPCNAMEOrder(context.Background(), "ip", "notarealhost", order)
604 if err == nil {
605 t.Errorf("%s: expected error while looking up name not in hosts file", name)
606 continue
607 }
608
609
610 addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "ip", "thor", order)
611 if err != nil {
612 t.Errorf("%s: expected to successfully lookup host entry", name)
613 continue
614 }
615 if len(addrs) != 1 {
616 t.Errorf("%s: expected exactly one result, but got %v", name, addrs)
617 continue
618 }
619 if got, want := addrs[0].String(), "127.1.1.1"; got != want {
620 t.Errorf("%s: address doesn't match expectation. got %v, want %v", name, got, want)
621 }
622 }
623 }
624
625
626
627
628
629 func TestErrorForOriginalNameWhenSearching(t *testing.T) {
630 defer dnsWaitGroup.Wait()
631
632 const fqdn = "doesnotexist.domain"
633
634 conf, err := newResolvConfTest()
635 if err != nil {
636 t.Fatal(err)
637 }
638 defer conf.teardown()
639
640 if err := conf.writeAndUpdate([]string{"search servfail"}); err != nil {
641 t.Fatal(err)
642 }
643
644 fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
645 r := dnsmessage.Message{
646 Header: dnsmessage.Header{
647 ID: q.ID,
648 Response: true,
649 },
650 Questions: q.Questions,
651 }
652
653 switch q.Questions[0].Name.String() {
654 case fqdn + ".servfail.":
655 r.Header.RCode = dnsmessage.RCodeServerFailure
656 default:
657 r.Header.RCode = dnsmessage.RCodeNameError
658 }
659
660 return r, nil
661 }}
662
663 cases := []struct {
664 strictErrors bool
665 wantErr *DNSError
666 }{
667 {true, &DNSError{Name: fqdn, Err: "server misbehaving", IsTemporary: true}},
668 {false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error(), IsNotFound: true}},
669 }
670 for _, tt := range cases {
671 r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext}
672 _, err = r.LookupIPAddr(context.Background(), fqdn)
673 if err == nil {
674 t.Fatal("expected an error")
675 }
676
677 want := tt.wantErr
678 if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err || err.IsTemporary != want.IsTemporary {
679 t.Errorf("got %v; want %v", err, want)
680 }
681 }
682 }
683
684
685 func TestIgnoreLameReferrals(t *testing.T) {
686 defer dnsWaitGroup.Wait()
687
688 conf, err := newResolvConfTest()
689 if err != nil {
690 t.Fatal(err)
691 }
692 defer conf.teardown()
693
694 if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1",
695 "nameserver 192.0.2.2"}); err != nil {
696 t.Fatal(err)
697 }
698
699 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
700 t.Log(s, q)
701 r := dnsmessage.Message{
702 Header: dnsmessage.Header{
703 ID: q.ID,
704 Response: true,
705 },
706 Questions: q.Questions,
707 }
708
709 if s == "192.0.2.2:53" {
710 r.Header.RecursionAvailable = true
711 if q.Questions[0].Type == dnsmessage.TypeA {
712 r.Answers = []dnsmessage.Resource{
713 {
714 Header: dnsmessage.ResourceHeader{
715 Name: q.Questions[0].Name,
716 Type: dnsmessage.TypeA,
717 Class: dnsmessage.ClassINET,
718 Length: 4,
719 },
720 Body: &dnsmessage.AResource{
721 A: TestAddr,
722 },
723 },
724 }
725 }
726 }
727
728 return r, nil
729 }}
730 r := Resolver{PreferGo: true, Dial: fake.DialContext}
731
732 addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org")
733 if err != nil {
734 t.Fatal(err)
735 }
736
737 if got := len(addrs); got != 1 {
738 t.Fatalf("got %d addresses, want 1", got)
739 }
740
741 if got, want := addrs[0].String(), "192.0.2.1"; got != want {
742 t.Fatalf("got address %v, want %v", got, want)
743 }
744 }
745
746 func BenchmarkGoLookupIP(b *testing.B) {
747 testHookUninstaller.Do(uninstallTestHooks)
748 ctx := context.Background()
749 b.ReportAllocs()
750
751 for i := 0; i < b.N; i++ {
752 goResolver.LookupIPAddr(ctx, "www.example.com")
753 }
754 }
755
756 func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
757 testHookUninstaller.Do(uninstallTestHooks)
758 ctx := context.Background()
759 b.ReportAllocs()
760
761 for i := 0; i < b.N; i++ {
762 goResolver.LookupIPAddr(ctx, "some.nonexistent")
763 }
764 }
765
766 func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
767 testHookUninstaller.Do(uninstallTestHooks)
768
769 conf, err := newResolvConfTest()
770 if err != nil {
771 b.Fatal(err)
772 }
773 defer conf.teardown()
774
775 lines := []string{
776 "nameserver 203.0.113.254",
777 "nameserver 8.8.8.8",
778 }
779 if err := conf.writeAndUpdate(lines); err != nil {
780 b.Fatal(err)
781 }
782 ctx := context.Background()
783 b.ReportAllocs()
784
785 for i := 0; i < b.N; i++ {
786 goResolver.LookupIPAddr(ctx, "www.example.com")
787 }
788 }
789
// fakeDNSServer stands in for a DNS server in tests. Its DialContext method
// has the shape of Resolver.Dial and returns in-memory connections that
// never touch the network.
type fakeDNSServer struct {
	// rh builds the response to query q. n and s are the network and
	// address passed to DialContext. t is the deadline most recently set on
	// the connection (the zero time if none was set). A non-nil error is
	// returned from the connection's Read in place of a response.
	rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
	// alwaysTCP makes DialContext return a stream (length-prefixed)
	// connection whatever network is requested.
	alwaysTCP bool
}
794
795 func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
796 if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
797 return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
798 }
799 return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
800 }
801
// fakeDNSConn is the connection handed out by fakeDNSServer.DialContext.
// Write decodes the outgoing query; Read returns the server handler's
// packed response to it.
type fakeDNSConn struct {
	// Conn is embedded to satisfy the interface. DialContext leaves it nil,
	// so calling any Conn method not defined on fakeDNSConn would panic.
	Conn
	// tcp reports whether messages carry a 2-byte length prefix.
	tcp bool
	// server is the fake server whose handler produces responses.
	server *fakeDNSServer
	// n and s are the network and address passed to DialContext.
	n string
	s string
	// q is the query most recently decoded by Write.
	q dnsmessage.Message
	// t is the deadline from the most recent SetDeadline call.
	t time.Time
	// buf holds the unread remainder of a stream response.
	buf []byte
}
812
813 func (f *fakeDNSConn) Close() error {
814 return nil
815 }
816
817 func (f *fakeDNSConn) Read(b []byte) (int, error) {
818 if len(f.buf) > 0 {
819 n := copy(b, f.buf)
820 f.buf = f.buf[n:]
821 return n, nil
822 }
823
824 resp, err := f.server.rh(f.n, f.s, f.q, f.t)
825 if err != nil {
826 return 0, err
827 }
828
829 bb := make([]byte, 2, 514)
830 bb, err = resp.AppendPack(bb)
831 if err != nil {
832 return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
833 }
834
835 if f.tcp {
836 l := len(bb) - 2
837 bb[0] = byte(l >> 8)
838 bb[1] = byte(l)
839 f.buf = bb
840 return f.Read(b)
841 }
842
843 bb = bb[2:]
844 if len(b) < len(bb) {
845 return 0, errors.New("read would fragment DNS message")
846 }
847
848 copy(b, bb)
849 return len(bb), nil
850 }
851
852 func (f *fakeDNSConn) Write(b []byte) (int, error) {
853 if f.tcp && len(b) >= 2 {
854 b = b[2:]
855 }
856 if f.q.Unpack(b) != nil {
857 return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
858 }
859 return len(b), nil
860 }
861
862 func (f *fakeDNSConn) SetDeadline(t time.Time) error {
863 f.t = t
864 return nil
865 }
866
// fakeDNSPacketConn is the datagram connection returned by
// fakeDNSServer.DialContext. PacketConn is embedded, and left nil, to
// satisfy that interface. The embedded fakeDNSConn provides Read and Write.
// Both embedded fields supply SetDeadline and Close at the same depth, so
// the type defines those methods explicitly to resolve the ambiguity.
type fakeDNSPacketConn struct {
	PacketConn
	fakeDNSConn
}
871
872 func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
873 return f.fakeDNSConn.SetDeadline(t)
874 }
875
876 func (f *fakeDNSPacketConn) Close() error {
877 return f.fakeDNSConn.Close()
878 }
879
880
881 func TestIgnoreDNSForgeries(t *testing.T) {
882 c, s := Pipe()
883 go func() {
884 b := make([]byte, maxDNSPacketSize)
885 n, err := s.Read(b)
886 if err != nil {
887 t.Error(err)
888 return
889 }
890
891 var msg dnsmessage.Message
892 if msg.Unpack(b[:n]) != nil {
893 t.Error("invalid DNS query:", err)
894 return
895 }
896
897 s.Write([]byte("garbage DNS response packet"))
898
899 msg.Header.Response = true
900 msg.Header.ID++
901
902 if b, err = msg.Pack(); err != nil {
903 t.Error("failed to pack DNS response:", err)
904 return
905 }
906 s.Write(b)
907
908 msg.Header.ID--
909 msg.Answers = []dnsmessage.Resource{
910 {
911 Header: dnsmessage.ResourceHeader{
912 Name: mustNewName("www.example.com."),
913 Type: dnsmessage.TypeA,
914 Class: dnsmessage.ClassINET,
915 Length: 4,
916 },
917 Body: &dnsmessage.AResource{
918 A: TestAddr,
919 },
920 },
921 }
922
923 b, err = msg.Pack()
924 if err != nil {
925 t.Error("failed to pack DNS response:", err)
926 return
927 }
928 s.Write(b)
929 }()
930
931 msg := dnsmessage.Message{
932 Header: dnsmessage.Header{
933 ID: 42,
934 },
935 Questions: []dnsmessage.Question{
936 {
937 Name: mustNewName("www.example.com."),
938 Type: dnsmessage.TypeA,
939 Class: dnsmessage.ClassINET,
940 },
941 },
942 }
943
944 b, err := msg.Pack()
945 if err != nil {
946 t.Fatal("Pack failed:", err)
947 }
948
949 p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b)
950 if err != nil {
951 t.Fatalf("dnsPacketRoundTrip failed: %v", err)
952 }
953
954 p.SkipAllQuestions()
955 as, err := p.AllAnswers()
956 if err != nil {
957 t.Fatal("AllAnswers failed:", err)
958 }
959 if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr {
960 t.Errorf("got address %v, want %v", got, TestAddr)
961 }
962 }
963
964
965 func TestRetryTimeout(t *testing.T) {
966 defer dnsWaitGroup.Wait()
967
968 conf, err := newResolvConfTest()
969 if err != nil {
970 t.Fatal(err)
971 }
972 defer conf.teardown()
973
974 testConf := []string{
975 "nameserver 192.0.2.1",
976 "nameserver 192.0.2.2",
977 }
978 if err := conf.writeAndUpdate(testConf); err != nil {
979 t.Fatal(err)
980 }
981
982 var deadline0 time.Time
983
984 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
985 t.Log(s, q, deadline)
986
987 if deadline.IsZero() {
988 t.Error("zero deadline")
989 }
990
991 if s == "192.0.2.1:53" {
992 deadline0 = deadline
993 time.Sleep(10 * time.Millisecond)
994 return dnsmessage.Message{}, os.ErrDeadlineExceeded
995 }
996
997 if deadline.Equal(deadline0) {
998 t.Error("deadline didn't change")
999 }
1000
1001 return mockTXTResponse(q), nil
1002 }}
1003 r := &Resolver{PreferGo: true, Dial: fake.DialContext}
1004
1005 _, err = r.LookupTXT(context.Background(), "www.golang.org")
1006 if err != nil {
1007 t.Fatal(err)
1008 }
1009
1010 if deadline0.IsZero() {
1011 t.Error("deadline0 still zero", deadline0)
1012 }
1013 }
1014
1015 func TestRotate(t *testing.T) {
1016
1017 testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"})
1018
1019
1020 testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"})
1021 }
1022
1023 func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
1024 defer dnsWaitGroup.Wait()
1025
1026 conf, err := newResolvConfTest()
1027 if err != nil {
1028 t.Fatal(err)
1029 }
1030 defer conf.teardown()
1031
1032 var confLines []string
1033 for _, ns := range nameservers {
1034 confLines = append(confLines, "nameserver "+ns)
1035 }
1036 if rotate {
1037 confLines = append(confLines, "options rotate")
1038 }
1039
1040 if err := conf.writeAndUpdate(confLines); err != nil {
1041 t.Fatal(err)
1042 }
1043
1044 var usedServers []string
1045 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
1046 usedServers = append(usedServers, s)
1047 return mockTXTResponse(q), nil
1048 }}
1049 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1050
1051
1052 for i := 0; i < len(nameservers)+1; i++ {
1053 if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil {
1054 t.Fatal(err)
1055 }
1056 }
1057
1058 if !reflect.DeepEqual(usedServers, wantServers) {
1059 t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers)
1060 }
1061 }
1062
1063 func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message {
1064 r := dnsmessage.Message{
1065 Header: dnsmessage.Header{
1066 ID: q.ID,
1067 Response: true,
1068 RecursionAvailable: true,
1069 },
1070 Questions: q.Questions,
1071 Answers: []dnsmessage.Resource{
1072 {
1073 Header: dnsmessage.ResourceHeader{
1074 Name: q.Questions[0].Name,
1075 Type: dnsmessage.TypeTXT,
1076 Class: dnsmessage.ClassINET,
1077 },
1078 Body: &dnsmessage.TXTResource{
1079 TXT: []string{"ok"},
1080 },
1081 },
1082 },
1083 }
1084
1085 return r
1086 }
1087
1088
1089
1090 func TestStrictErrorsLookupIP(t *testing.T) {
1091 defer dnsWaitGroup.Wait()
1092
1093 conf, err := newResolvConfTest()
1094 if err != nil {
1095 t.Fatal(err)
1096 }
1097 defer conf.teardown()
1098
1099 confData := []string{
1100 "nameserver 192.0.2.53",
1101 "search x.golang.org y.golang.org",
1102 }
1103 if err := conf.writeAndUpdate(confData); err != nil {
1104 t.Fatal(err)
1105 }
1106
1107 const name = "test-issue19592"
1108 const server = "192.0.2.53:53"
1109 const searchX = "test-issue19592.x.golang.org."
1110 const searchY = "test-issue19592.y.golang.org."
1111 const ip4 = "192.0.2.1"
1112 const ip6 = "2001:db8::1"
1113
1114 type resolveWhichEnum int
1115 const (
1116 resolveOK resolveWhichEnum = iota
1117 resolveOpError
1118 resolveServfail
1119 resolveTimeout
1120 )
1121
1122 makeTempError := func(err string) error {
1123 return &DNSError{
1124 Err: err,
1125 Name: name,
1126 Server: server,
1127 IsTemporary: true,
1128 }
1129 }
1130 makeTimeout := func() error {
1131 return &DNSError{
1132 Err: os.ErrDeadlineExceeded.Error(),
1133 Name: name,
1134 Server: server,
1135 IsTimeout: true,
1136 }
1137 }
1138 makeNxDomain := func() error {
1139 return &DNSError{
1140 Err: errNoSuchHost.Error(),
1141 Name: name,
1142 Server: server,
1143 IsNotFound: true,
1144 }
1145 }
1146
1147 cases := []struct {
1148 desc string
1149 resolveWhich func(quest dnsmessage.Question) resolveWhichEnum
1150 wantStrictErr error
1151 wantLaxErr error
1152 wantIPs []string
1153 }{
1154 {
1155 desc: "No errors",
1156 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1157 return resolveOK
1158 },
1159 wantIPs: []string{ip4, ip6},
1160 },
1161 {
1162 desc: "searchX error fails in strict mode",
1163 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1164 if quest.Name.String() == searchX {
1165 return resolveTimeout
1166 }
1167 return resolveOK
1168 },
1169 wantStrictErr: makeTimeout(),
1170 wantIPs: []string{ip4, ip6},
1171 },
1172 {
1173 desc: "searchX IPv4-only timeout fails in strict mode",
1174 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1175 if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA {
1176 return resolveTimeout
1177 }
1178 return resolveOK
1179 },
1180 wantStrictErr: makeTimeout(),
1181 wantIPs: []string{ip4, ip6},
1182 },
1183 {
1184 desc: "searchX IPv6-only servfail fails in strict mode",
1185 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1186 if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA {
1187 return resolveServfail
1188 }
1189 return resolveOK
1190 },
1191 wantStrictErr: makeTempError("server misbehaving"),
1192 wantIPs: []string{ip4, ip6},
1193 },
1194 {
1195 desc: "searchY error always fails",
1196 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1197 if quest.Name.String() == searchY {
1198 return resolveTimeout
1199 }
1200 return resolveOK
1201 },
1202 wantStrictErr: makeTimeout(),
1203 wantLaxErr: makeNxDomain(),
1204 },
1205 {
1206 desc: "searchY IPv4-only socket error fails in strict mode",
1207 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1208 if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA {
1209 return resolveOpError
1210 }
1211 return resolveOK
1212 },
1213 wantStrictErr: makeTempError("write: socket on fire"),
1214 wantIPs: []string{ip6},
1215 },
1216 {
1217 desc: "searchY IPv6-only timeout fails in strict mode",
1218 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1219 if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA {
1220 return resolveTimeout
1221 }
1222 return resolveOK
1223 },
1224 wantStrictErr: makeTimeout(),
1225 wantIPs: []string{ip4},
1226 },
1227 }
1228
1229 for i, tt := range cases {
1230 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
1231 t.Log(s, q)
1232
1233 switch tt.resolveWhich(q.Questions[0]) {
1234 case resolveOK:
1235
1236 case resolveOpError:
1237 return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
1238 case resolveServfail:
1239 return dnsmessage.Message{
1240 Header: dnsmessage.Header{
1241 ID: q.ID,
1242 Response: true,
1243 RCode: dnsmessage.RCodeServerFailure,
1244 },
1245 Questions: q.Questions,
1246 }, nil
1247 case resolveTimeout:
1248 return dnsmessage.Message{}, os.ErrDeadlineExceeded
1249 default:
1250 t.Fatal("Impossible resolveWhich")
1251 }
1252
1253 switch q.Questions[0].Name.String() {
1254 case searchX, name + ".":
1255
1256 return dnsmessage.Message{
1257 Header: dnsmessage.Header{
1258 ID: q.ID,
1259 Response: true,
1260 RCode: dnsmessage.RCodeNameError,
1261 },
1262 Questions: q.Questions,
1263 }, nil
1264 case searchY:
1265
1266 default:
1267 return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
1268 }
1269
1270 r := dnsmessage.Message{
1271 Header: dnsmessage.Header{
1272 ID: q.ID,
1273 Response: true,
1274 },
1275 Questions: q.Questions,
1276 }
1277 switch q.Questions[0].Type {
1278 case dnsmessage.TypeA:
1279 r.Answers = []dnsmessage.Resource{
1280 {
1281 Header: dnsmessage.ResourceHeader{
1282 Name: q.Questions[0].Name,
1283 Type: dnsmessage.TypeA,
1284 Class: dnsmessage.ClassINET,
1285 Length: 4,
1286 },
1287 Body: &dnsmessage.AResource{
1288 A: TestAddr,
1289 },
1290 },
1291 }
1292 case dnsmessage.TypeAAAA:
1293 r.Answers = []dnsmessage.Resource{
1294 {
1295 Header: dnsmessage.ResourceHeader{
1296 Name: q.Questions[0].Name,
1297 Type: dnsmessage.TypeAAAA,
1298 Class: dnsmessage.ClassINET,
1299 Length: 16,
1300 },
1301 Body: &dnsmessage.AAAAResource{
1302 AAAA: TestAddr6,
1303 },
1304 },
1305 }
1306 default:
1307 return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type)
1308 }
1309 return r, nil
1310 }}
1311
1312 for _, strict := range []bool{true, false} {
1313 r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext}
1314 ips, err := r.LookupIPAddr(context.Background(), name)
1315
1316 var wantErr error
1317 if strict {
1318 wantErr = tt.wantStrictErr
1319 } else {
1320 wantErr = tt.wantLaxErr
1321 }
1322 if !reflect.DeepEqual(err, wantErr) {
1323 t.Errorf("#%d (%s) strict=%v: got err %#v; want %#v", i, tt.desc, strict, err, wantErr)
1324 }
1325
1326 gotIPs := map[string]struct{}{}
1327 for _, ip := range ips {
1328 gotIPs[ip.String()] = struct{}{}
1329 }
1330 wantIPs := map[string]struct{}{}
1331 if wantErr == nil {
1332 for _, ip := range tt.wantIPs {
1333 wantIPs[ip] = struct{}{}
1334 }
1335 }
1336 if !reflect.DeepEqual(gotIPs, wantIPs) {
1337 t.Errorf("#%d (%s) strict=%v: got ips %v; want %v", i, tt.desc, strict, gotIPs, wantIPs)
1338 }
1339 }
1340 }
1341 }
1342
1343
1344
1345 func TestStrictErrorsLookupTXT(t *testing.T) {
1346 defer dnsWaitGroup.Wait()
1347
1348 conf, err := newResolvConfTest()
1349 if err != nil {
1350 t.Fatal(err)
1351 }
1352 defer conf.teardown()
1353
1354 confData := []string{
1355 "nameserver 192.0.2.53",
1356 "search x.golang.org y.golang.org",
1357 }
1358 if err := conf.writeAndUpdate(confData); err != nil {
1359 t.Fatal(err)
1360 }
1361
1362 const name = "test"
1363 const server = "192.0.2.53:53"
1364 const searchX = "test.x.golang.org."
1365 const searchY = "test.y.golang.org."
1366 const txt = "Hello World"
1367
1368 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
1369 t.Log(s, q)
1370
1371 switch q.Questions[0].Name.String() {
1372 case searchX:
1373 return dnsmessage.Message{}, os.ErrDeadlineExceeded
1374 case searchY:
1375 return mockTXTResponse(q), nil
1376 default:
1377 return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
1378 }
1379 }}
1380
1381 for _, strict := range []bool{true, false} {
1382 r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
1383 p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT)
1384 var wantErr error
1385 var wantRRs int
1386 if strict {
1387 wantErr = &DNSError{
1388 Err: os.ErrDeadlineExceeded.Error(),
1389 Name: name,
1390 Server: server,
1391 IsTimeout: true,
1392 }
1393 } else {
1394 wantRRs = 1
1395 }
1396 if !reflect.DeepEqual(err, wantErr) {
1397 t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
1398 }
1399 a, err := p.AllAnswers()
1400 if err != nil {
1401 a = nil
1402 }
1403 if len(a) != wantRRs {
1404 t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs)
1405 }
1406 }
1407 }
1408
1409
1410
1411 func TestDNSGoroutineRace(t *testing.T) {
1412 defer dnsWaitGroup.Wait()
1413
1414 fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
1415 time.Sleep(10 * time.Microsecond)
1416 return dnsmessage.Message{}, os.ErrDeadlineExceeded
1417 }}
1418 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1419
1420
1421
1422
1423 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond)
1424 defer cancel()
1425 _, err := r.LookupIPAddr(ctx, "where.are.they.now")
1426 if err == nil {
1427 t.Fatal("fake DNS lookup unexpectedly succeeded")
1428 }
1429 }
1430
1431 func lookupWithFake(fake fakeDNSServer, name string, typ dnsmessage.Type) error {
1432 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1433
1434 resolvConf.mu.RLock()
1435 conf := resolvConf.dnsConfig
1436 resolvConf.mu.RUnlock()
1437
1438 ctx, cancel := context.WithCancel(context.Background())
1439 defer cancel()
1440
1441 _, _, err := r.tryOneName(ctx, conf, name, typ)
1442 return err
1443 }
1444
1445
1446
1447 func TestIssue8434(t *testing.T) {
1448 err := lookupWithFake(fakeDNSServer{
1449 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1450 return dnsmessage.Message{
1451 Header: dnsmessage.Header{
1452 ID: q.ID,
1453 Response: true,
1454 RCode: dnsmessage.RCodeServerFailure,
1455 },
1456 Questions: q.Questions,
1457 }, nil
1458 },
1459 }, "golang.org.", dnsmessage.TypeALL)
1460 if err == nil {
1461 t.Fatal("expected an error")
1462 }
1463 if ne, ok := err.(Error); !ok {
1464 t.Fatalf("err = %#v; wanted something supporting net.Error", err)
1465 } else if !ne.Temporary() {
1466 t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
1467 }
1468 if de, ok := err.(*DNSError); !ok {
1469 t.Fatalf("err = %#v; wanted a *net.DNSError", err)
1470 } else if !de.IsTemporary {
1471 t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
1472 }
1473 }
1474
1475 func TestIssueNoSuchHostExists(t *testing.T) {
1476 err := lookupWithFake(fakeDNSServer{
1477 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1478 return dnsmessage.Message{
1479 Header: dnsmessage.Header{
1480 ID: q.ID,
1481 Response: true,
1482 RCode: dnsmessage.RCodeNameError,
1483 },
1484 Questions: q.Questions,
1485 }, nil
1486 },
1487 }, "golang.org.", dnsmessage.TypeALL)
1488 if err == nil {
1489 t.Fatal("expected an error")
1490 }
1491 if _, ok := err.(Error); !ok {
1492 t.Fatalf("err = %#v; wanted something supporting net.Error", err)
1493 }
1494 if de, ok := err.(*DNSError); !ok {
1495 t.Fatalf("err = %#v; wanted a *net.DNSError", err)
1496 } else if !de.IsNotFound {
1497 t.Fatalf("IsNotFound = false for err = %#v; want IsNotFound == true", err)
1498 }
1499 }
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510 func TestNoSuchHost(t *testing.T) {
1511 tests := []struct {
1512 name string
1513 f func(string, string, dnsmessage.Message, time.Time) (dnsmessage.Message, error)
1514 }{
1515 {
1516 "NXDOMAIN",
1517 func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1518 return dnsmessage.Message{
1519 Header: dnsmessage.Header{
1520 ID: q.ID,
1521 Response: true,
1522 RCode: dnsmessage.RCodeNameError,
1523 RecursionAvailable: false,
1524 },
1525 Questions: q.Questions,
1526 }, nil
1527 },
1528 },
1529 {
1530 "no answers",
1531 func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1532 return dnsmessage.Message{
1533 Header: dnsmessage.Header{
1534 ID: q.ID,
1535 Response: true,
1536 RCode: dnsmessage.RCodeSuccess,
1537 RecursionAvailable: false,
1538 Authoritative: true,
1539 },
1540 Questions: q.Questions,
1541 }, nil
1542 },
1543 },
1544 }
1545
1546 for _, test := range tests {
1547 t.Run(test.name, func(t *testing.T) {
1548 lookups := 0
1549 err := lookupWithFake(fakeDNSServer{
1550 rh: func(n, s string, q dnsmessage.Message, d time.Time) (dnsmessage.Message, error) {
1551 lookups++
1552 return test.f(n, s, q, d)
1553 },
1554 }, ".", dnsmessage.TypeALL)
1555
1556 if lookups != 1 {
1557 t.Errorf("got %d lookups, wanted 1", lookups)
1558 }
1559
1560 if err == nil {
1561 t.Fatal("expected an error")
1562 }
1563 de, ok := err.(*DNSError)
1564 if !ok {
1565 t.Fatalf("err = %#v; wanted a *net.DNSError", err)
1566 }
1567 if de.Err != errNoSuchHost.Error() {
1568 t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
1569 }
1570 if !de.IsNotFound {
1571 t.Fatalf("IsNotFound = %v wanted true", de.IsNotFound)
1572 }
1573 })
1574 }
1575 }
1576
1577
1578
1579 func TestDNSDialTCP(t *testing.T) {
1580 fake := fakeDNSServer{
1581 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1582 r := dnsmessage.Message{
1583 Header: dnsmessage.Header{
1584 ID: q.Header.ID,
1585 Response: true,
1586 RCode: dnsmessage.RCodeSuccess,
1587 },
1588 Questions: q.Questions,
1589 }
1590 return r, nil
1591 },
1592 alwaysTCP: true,
1593 }
1594 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1595 ctx := context.Background()
1596 _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useUDPOrTCP)
1597 if err != nil {
1598 t.Fatal("exhange failed:", err)
1599 }
1600 }
1601
1602
1603 func TestTXTRecordTwoStrings(t *testing.T) {
1604 fake := fakeDNSServer{
1605 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1606 r := dnsmessage.Message{
1607 Header: dnsmessage.Header{
1608 ID: q.Header.ID,
1609 Response: true,
1610 RCode: dnsmessage.RCodeSuccess,
1611 },
1612 Questions: q.Questions,
1613 Answers: []dnsmessage.Resource{
1614 {
1615 Header: dnsmessage.ResourceHeader{
1616 Name: q.Questions[0].Name,
1617 Type: dnsmessage.TypeA,
1618 Class: dnsmessage.ClassINET,
1619 },
1620 Body: &dnsmessage.TXTResource{
1621 TXT: []string{"string1 ", "string2"},
1622 },
1623 },
1624 {
1625 Header: dnsmessage.ResourceHeader{
1626 Name: q.Questions[0].Name,
1627 Type: dnsmessage.TypeA,
1628 Class: dnsmessage.ClassINET,
1629 },
1630 Body: &dnsmessage.TXTResource{
1631 TXT: []string{"onestring"},
1632 },
1633 },
1634 },
1635 }
1636 return r, nil
1637 },
1638 }
1639 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1640 txt, err := r.lookupTXT(context.Background(), "golang.org")
1641 if err != nil {
1642 t.Fatal("LookupTXT failed:", err)
1643 }
1644 if want := 2; len(txt) != want {
1645 t.Fatalf("len(txt), got %d, want %d", len(txt), want)
1646 }
1647 if want := "string1 string2"; txt[0] != want {
1648 t.Errorf("txt[0], got %q, want %q", txt[0], want)
1649 }
1650 if want := "onestring"; txt[1] != want {
1651 t.Errorf("txt[1], got %q, want %q", txt[1], want)
1652 }
1653 }
1654
1655
1656
1657 func TestSingleRequestLookup(t *testing.T) {
1658 defer dnsWaitGroup.Wait()
1659 var (
1660 firstcalled int32
1661 ipv4 int32 = 1
1662 ipv6 int32 = 2
1663 )
1664 fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1665 r := dnsmessage.Message{
1666 Header: dnsmessage.Header{
1667 ID: q.ID,
1668 Response: true,
1669 },
1670 Questions: q.Questions,
1671 }
1672 for _, question := range q.Questions {
1673 switch question.Type {
1674 case dnsmessage.TypeA:
1675 if question.Name.String() == "slowipv4.example.net." {
1676 time.Sleep(10 * time.Millisecond)
1677 }
1678 if !atomic.CompareAndSwapInt32(&firstcalled, 0, ipv4) {
1679 t.Errorf("the A query was received after the AAAA query !")
1680 }
1681 r.Answers = append(r.Answers, dnsmessage.Resource{
1682 Header: dnsmessage.ResourceHeader{
1683 Name: q.Questions[0].Name,
1684 Type: dnsmessage.TypeA,
1685 Class: dnsmessage.ClassINET,
1686 Length: 4,
1687 },
1688 Body: &dnsmessage.AResource{
1689 A: TestAddr,
1690 },
1691 })
1692 case dnsmessage.TypeAAAA:
1693 atomic.CompareAndSwapInt32(&firstcalled, 0, ipv6)
1694 r.Answers = append(r.Answers, dnsmessage.Resource{
1695 Header: dnsmessage.ResourceHeader{
1696 Name: q.Questions[0].Name,
1697 Type: dnsmessage.TypeAAAA,
1698 Class: dnsmessage.ClassINET,
1699 Length: 16,
1700 },
1701 Body: &dnsmessage.AAAAResource{
1702 AAAA: TestAddr6,
1703 },
1704 })
1705 }
1706 }
1707 return r, nil
1708 }}
1709 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1710
1711 conf, err := newResolvConfTest()
1712 if err != nil {
1713 t.Fatal(err)
1714 }
1715 defer conf.teardown()
1716 if err := conf.writeAndUpdate([]string{"options single-request"}); err != nil {
1717 t.Fatal(err)
1718 }
1719 for _, name := range []string{"hostname.example.net", "slowipv4.example.net"} {
1720 firstcalled = 0
1721 _, err := r.LookupIPAddr(context.Background(), name)
1722 if err != nil {
1723 t.Error(err)
1724 }
1725 }
1726 }
1727
1728
1729 func TestDNSUseTCP(t *testing.T) {
1730 fake := fakeDNSServer{
1731 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1732 r := dnsmessage.Message{
1733 Header: dnsmessage.Header{
1734 ID: q.Header.ID,
1735 Response: true,
1736 RCode: dnsmessage.RCodeSuccess,
1737 },
1738 Questions: q.Questions,
1739 }
1740 if n == "udp" {
1741 t.Fatal("udp protocol was used instead of tcp")
1742 }
1743 return r, nil
1744 },
1745 }
1746 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1747 ctx, cancel := context.WithCancel(context.Background())
1748 defer cancel()
1749 _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useTCPOnly)
1750 if err != nil {
1751 t.Fatal("exchange failed:", err)
1752 }
1753 }
1754
1755
1756 func TestPTRandNonPTR(t *testing.T) {
1757 fake := fakeDNSServer{
1758 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1759 r := dnsmessage.Message{
1760 Header: dnsmessage.Header{
1761 ID: q.Header.ID,
1762 Response: true,
1763 RCode: dnsmessage.RCodeSuccess,
1764 },
1765 Questions: q.Questions,
1766 Answers: []dnsmessage.Resource{
1767 {
1768 Header: dnsmessage.ResourceHeader{
1769 Name: q.Questions[0].Name,
1770 Type: dnsmessage.TypePTR,
1771 Class: dnsmessage.ClassINET,
1772 },
1773 Body: &dnsmessage.PTRResource{
1774 PTR: dnsmessage.MustNewName("golang.org."),
1775 },
1776 },
1777 {
1778 Header: dnsmessage.ResourceHeader{
1779 Name: q.Questions[0].Name,
1780 Type: dnsmessage.TypeTXT,
1781 Class: dnsmessage.ClassINET,
1782 },
1783 Body: &dnsmessage.TXTResource{
1784 TXT: []string{"PTR 8 6 60 ..."},
1785 },
1786 },
1787 },
1788 }
1789 return r, nil
1790 },
1791 }
1792 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1793 names, err := r.lookupAddr(context.Background(), "192.0.2.123")
1794 if err != nil {
1795 t.Fatalf("LookupAddr: %v", err)
1796 }
1797 if want := []string{"golang.org."}; !reflect.DeepEqual(names, want) {
1798 t.Errorf("names = %q; want %q", names, want)
1799 }
1800 }
1801
1802 func TestCVE202133195(t *testing.T) {
1803 fake := fakeDNSServer{
1804 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1805 r := dnsmessage.Message{
1806 Header: dnsmessage.Header{
1807 ID: q.Header.ID,
1808 Response: true,
1809 RCode: dnsmessage.RCodeSuccess,
1810 RecursionAvailable: true,
1811 },
1812 Questions: q.Questions,
1813 }
1814 switch q.Questions[0].Type {
1815 case dnsmessage.TypeCNAME:
1816 r.Answers = []dnsmessage.Resource{}
1817 case dnsmessage.TypeA:
1818 r.Answers = append(r.Answers,
1819 dnsmessage.Resource{
1820 Header: dnsmessage.ResourceHeader{
1821 Name: dnsmessage.MustNewName("<html>.golang.org."),
1822 Type: dnsmessage.TypeA,
1823 Class: dnsmessage.ClassINET,
1824 Length: 4,
1825 },
1826 Body: &dnsmessage.AResource{
1827 A: TestAddr,
1828 },
1829 },
1830 )
1831 case dnsmessage.TypeSRV:
1832 n := q.Questions[0].Name
1833 if n.String() == "_hdr._tcp.golang.org." {
1834 n = dnsmessage.MustNewName("<html>.golang.org.")
1835 }
1836 r.Answers = append(r.Answers,
1837 dnsmessage.Resource{
1838 Header: dnsmessage.ResourceHeader{
1839 Name: n,
1840 Type: dnsmessage.TypeSRV,
1841 Class: dnsmessage.ClassINET,
1842 Length: 4,
1843 },
1844 Body: &dnsmessage.SRVResource{
1845 Target: dnsmessage.MustNewName("<html>.golang.org."),
1846 },
1847 },
1848 dnsmessage.Resource{
1849 Header: dnsmessage.ResourceHeader{
1850 Name: n,
1851 Type: dnsmessage.TypeSRV,
1852 Class: dnsmessage.ClassINET,
1853 Length: 4,
1854 },
1855 Body: &dnsmessage.SRVResource{
1856 Target: dnsmessage.MustNewName("good.golang.org."),
1857 },
1858 },
1859 )
1860 case dnsmessage.TypeMX:
1861 r.Answers = append(r.Answers,
1862 dnsmessage.Resource{
1863 Header: dnsmessage.ResourceHeader{
1864 Name: dnsmessage.MustNewName("<html>.golang.org."),
1865 Type: dnsmessage.TypeMX,
1866 Class: dnsmessage.ClassINET,
1867 Length: 4,
1868 },
1869 Body: &dnsmessage.MXResource{
1870 MX: dnsmessage.MustNewName("<html>.golang.org."),
1871 },
1872 },
1873 dnsmessage.Resource{
1874 Header: dnsmessage.ResourceHeader{
1875 Name: dnsmessage.MustNewName("good.golang.org."),
1876 Type: dnsmessage.TypeMX,
1877 Class: dnsmessage.ClassINET,
1878 Length: 4,
1879 },
1880 Body: &dnsmessage.MXResource{
1881 MX: dnsmessage.MustNewName("good.golang.org."),
1882 },
1883 },
1884 )
1885 case dnsmessage.TypeNS:
1886 r.Answers = append(r.Answers,
1887 dnsmessage.Resource{
1888 Header: dnsmessage.ResourceHeader{
1889 Name: dnsmessage.MustNewName("<html>.golang.org."),
1890 Type: dnsmessage.TypeNS,
1891 Class: dnsmessage.ClassINET,
1892 Length: 4,
1893 },
1894 Body: &dnsmessage.NSResource{
1895 NS: dnsmessage.MustNewName("<html>.golang.org."),
1896 },
1897 },
1898 dnsmessage.Resource{
1899 Header: dnsmessage.ResourceHeader{
1900 Name: dnsmessage.MustNewName("good.golang.org."),
1901 Type: dnsmessage.TypeNS,
1902 Class: dnsmessage.ClassINET,
1903 Length: 4,
1904 },
1905 Body: &dnsmessage.NSResource{
1906 NS: dnsmessage.MustNewName("good.golang.org."),
1907 },
1908 },
1909 )
1910 case dnsmessage.TypePTR:
1911 r.Answers = append(r.Answers,
1912 dnsmessage.Resource{
1913 Header: dnsmessage.ResourceHeader{
1914 Name: dnsmessage.MustNewName("<html>.golang.org."),
1915 Type: dnsmessage.TypePTR,
1916 Class: dnsmessage.ClassINET,
1917 Length: 4,
1918 },
1919 Body: &dnsmessage.PTRResource{
1920 PTR: dnsmessage.MustNewName("<html>.golang.org."),
1921 },
1922 },
1923 dnsmessage.Resource{
1924 Header: dnsmessage.ResourceHeader{
1925 Name: dnsmessage.MustNewName("good.golang.org."),
1926 Type: dnsmessage.TypePTR,
1927 Class: dnsmessage.ClassINET,
1928 Length: 4,
1929 },
1930 Body: &dnsmessage.PTRResource{
1931 PTR: dnsmessage.MustNewName("good.golang.org."),
1932 },
1933 },
1934 )
1935 }
1936 return r, nil
1937 },
1938 }
1939
1940 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1941
1942 originalDefault := DefaultResolver
1943 DefaultResolver = &r
1944 defer func() { DefaultResolver = originalDefault }()
1945
1946 defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
1947 testHookHostsPath = "testdata/hosts"
1948
1949 tests := []struct {
1950 name string
1951 f func(*testing.T)
1952 }{
1953 {
1954 name: "CNAME",
1955 f: func(t *testing.T) {
1956 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
1957 _, err := r.LookupCNAME(context.Background(), "golang.org")
1958 if err.Error() != expectedErr.Error() {
1959 t.Fatalf("unexpected error: %s", err)
1960 }
1961 _, err = LookupCNAME("golang.org")
1962 if err.Error() != expectedErr.Error() {
1963 t.Fatalf("unexpected error: %s", err)
1964 }
1965 },
1966 },
1967 {
1968 name: "SRV (bad record)",
1969 f: func(t *testing.T) {
1970 expected := []*SRV{
1971 {
1972 Target: "good.golang.org.",
1973 },
1974 }
1975 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
1976 _, records, err := r.LookupSRV(context.Background(), "target", "tcp", "golang.org")
1977 if err.Error() != expectedErr.Error() {
1978 t.Fatalf("unexpected error: %s", err)
1979 }
1980 if !reflect.DeepEqual(records, expected) {
1981 t.Error("Unexpected record set")
1982 }
1983 _, records, err = LookupSRV("target", "tcp", "golang.org")
1984 if err.Error() != expectedErr.Error() {
1985 t.Errorf("unexpected error: %s", err)
1986 }
1987 if !reflect.DeepEqual(records, expected) {
1988 t.Error("Unexpected record set")
1989 }
1990 },
1991 },
1992 {
1993 name: "SRV (bad header)",
1994 f: func(t *testing.T) {
1995 _, _, err := r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org.")
1996 if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
1997 t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
1998 }
1999 _, _, err = LookupSRV("hdr", "tcp", "golang.org.")
2000 if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
2001 t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
2002 }
2003 },
2004 },
2005 {
2006 name: "MX",
2007 f: func(t *testing.T) {
2008 expected := []*MX{
2009 {
2010 Host: "good.golang.org.",
2011 },
2012 }
2013 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
2014 records, err := r.LookupMX(context.Background(), "golang.org")
2015 if err.Error() != expectedErr.Error() {
2016 t.Fatalf("unexpected error: %s", err)
2017 }
2018 if !reflect.DeepEqual(records, expected) {
2019 t.Error("Unexpected record set")
2020 }
2021 records, err = LookupMX("golang.org")
2022 if err.Error() != expectedErr.Error() {
2023 t.Fatalf("unexpected error: %s", err)
2024 }
2025 if !reflect.DeepEqual(records, expected) {
2026 t.Error("Unexpected record set")
2027 }
2028 },
2029 },
2030 {
2031 name: "NS",
2032 f: func(t *testing.T) {
2033 expected := []*NS{
2034 {
2035 Host: "good.golang.org.",
2036 },
2037 }
2038 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
2039 records, err := r.LookupNS(context.Background(), "golang.org")
2040 if err.Error() != expectedErr.Error() {
2041 t.Fatalf("unexpected error: %s", err)
2042 }
2043 if !reflect.DeepEqual(records, expected) {
2044 t.Error("Unexpected record set")
2045 }
2046 records, err = LookupNS("golang.org")
2047 if err.Error() != expectedErr.Error() {
2048 t.Fatalf("unexpected error: %s", err)
2049 }
2050 if !reflect.DeepEqual(records, expected) {
2051 t.Error("Unexpected record set")
2052 }
2053 },
2054 },
2055 {
2056 name: "Addr",
2057 f: func(t *testing.T) {
2058 expected := []string{"good.golang.org."}
2059 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "192.0.2.42"}
2060 records, err := r.LookupAddr(context.Background(), "192.0.2.42")
2061 if err.Error() != expectedErr.Error() {
2062 t.Fatalf("unexpected error: %s", err)
2063 }
2064 if !reflect.DeepEqual(records, expected) {
2065 t.Error("Unexpected record set")
2066 }
2067 records, err = LookupAddr("192.0.2.42")
2068 if err.Error() != expectedErr.Error() {
2069 t.Fatalf("unexpected error: %s", err)
2070 }
2071 if !reflect.DeepEqual(records, expected) {
2072 t.Error("Unexpected record set")
2073 }
2074 },
2075 },
2076 }
2077
2078 for _, tc := range tests {
2079 t.Run(tc.name, tc.f)
2080 }
2081
2082 }
2083
2084 func TestNullMX(t *testing.T) {
2085 fake := fakeDNSServer{
2086 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
2087 r := dnsmessage.Message{
2088 Header: dnsmessage.Header{
2089 ID: q.Header.ID,
2090 Response: true,
2091 RCode: dnsmessage.RCodeSuccess,
2092 },
2093 Questions: q.Questions,
2094 Answers: []dnsmessage.Resource{
2095 {
2096 Header: dnsmessage.ResourceHeader{
2097 Name: q.Questions[0].Name,
2098 Type: dnsmessage.TypeMX,
2099 Class: dnsmessage.ClassINET,
2100 },
2101 Body: &dnsmessage.MXResource{
2102 MX: dnsmessage.MustNewName("."),
2103 },
2104 },
2105 },
2106 }
2107 return r, nil
2108 },
2109 }
2110 r := Resolver{PreferGo: true, Dial: fake.DialContext}
2111 rrset, err := r.LookupMX(context.Background(), "golang.org")
2112 if err != nil {
2113 t.Fatalf("LookupMX: %v", err)
2114 }
2115 if want := []*MX{&MX{Host: "."}}; !reflect.DeepEqual(rrset, want) {
2116 records := []string{}
2117 for _, rr := range rrset {
2118 records = append(records, fmt.Sprintf("%v", rr))
2119 }
2120 t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0])
2121 }
2122 }
2123
2124 func TestRootNS(t *testing.T) {
2125
2126 fake := fakeDNSServer{
2127 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
2128 r := dnsmessage.Message{
2129 Header: dnsmessage.Header{
2130 ID: q.Header.ID,
2131 Response: true,
2132 RCode: dnsmessage.RCodeSuccess,
2133 },
2134 Questions: q.Questions,
2135 Answers: []dnsmessage.Resource{
2136 {
2137 Header: dnsmessage.ResourceHeader{
2138 Name: q.Questions[0].Name,
2139 Type: dnsmessage.TypeNS,
2140 Class: dnsmessage.ClassINET,
2141 },
2142 Body: &dnsmessage.NSResource{
2143 NS: dnsmessage.MustNewName("i.root-servers.net."),
2144 },
2145 },
2146 },
2147 }
2148 return r, nil
2149 },
2150 }
2151 r := Resolver{PreferGo: true, Dial: fake.DialContext}
2152 rrset, err := r.LookupNS(context.Background(), ".")
2153 if err != nil {
2154 t.Fatalf("LookupNS: %v", err)
2155 }
2156 if want := []*NS{&NS{Host: "i.root-servers.net."}}; !reflect.DeepEqual(rrset, want) {
2157 records := []string{}
2158 for _, rr := range rrset {
2159 records = append(records, fmt.Sprintf("%v", rr))
2160 }
2161 t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0])
2162 }
2163 }
2164
View as plain text