Source file
src/net/dnsclient_unix.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package net
16
17 import (
18 "context"
19 "errors"
20 "internal/itoa"
21 "io"
22 "os"
23 "sync"
24 "time"
25
26 "golang.org/x/net/dns/dnsmessage"
27 )
28
// Values for the useTCP argument of Resolver.exchange.
const useTCPOnly = true   // send the query over TCP only
const useUDPOrTCP = false // try UDP first, retrying over TCP if the reply is truncated

// maxDNSPacketSize is the size of the buffer dnsPacketRoundTrip allocates to
// receive a single UDP response.
const maxDNSPacketSize = 1232
38
// errLameReferral is returned by checkHeader for a successful response that
// is neither authoritative nor recursive and carries no answers.
var errLameReferral = errors.New("lame referral")

// errCannotUnmarshalDNSMessage reports a response that could not be parsed.
var errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")

// errCannotMarshalDNSMessage reports a query that could not be encoded.
var errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")

// errServerMisbehaving reports a response with an unexpected non-success RCode.
var errServerMisbehaving = errors.New("server misbehaving")

// errInvalidDNSResponse reports a response that does not answer the query
// that was sent (see checkResponse) or has an unexpected question section.
var errInvalidDNSResponse = errors.New("invalid DNS response")

// errNoAnswerFromDNSServer is returned by exchange when every transport it
// tried produced a truncated response.
var errNoAnswerFromDNSServer = errors.New("no answer from DNS server")

// errServerTemporarilyMisbehaving is returned for SERVFAIL responses. Its text
// is the same as errServerMisbehaving's; it is a distinct value so that
// tryOneName can mark the resulting DNSError as temporary.
var errServerTemporarilyMisbehaving = errors.New("server misbehaving")
52
53 func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
54 id = uint16(randInt())
55 b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
56 b.EnableCompression()
57 if err := b.StartQuestions(); err != nil {
58 return 0, nil, nil, err
59 }
60 if err := b.Question(q); err != nil {
61 return 0, nil, nil, err
62 }
63 tcpReq, err = b.Finish()
64 udpReq = tcpReq[2:]
65 l := len(tcpReq) - 2
66 tcpReq[0] = byte(l >> 8)
67 tcpReq[1] = byte(l)
68 return id, udpReq, tcpReq, err
69 }
70
71 func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
72 if !respHdr.Response {
73 return false
74 }
75 if reqID != respHdr.ID {
76 return false
77 }
78 if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
79 return false
80 }
81 return true
82 }
83
84 func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
85 if _, err := c.Write(b); err != nil {
86 return dnsmessage.Parser{}, dnsmessage.Header{}, err
87 }
88
89 b = make([]byte, maxDNSPacketSize)
90 for {
91 n, err := c.Read(b)
92 if err != nil {
93 return dnsmessage.Parser{}, dnsmessage.Header{}, err
94 }
95 var p dnsmessage.Parser
96
97
98
99 h, err := p.Start(b[:n])
100 if err != nil {
101 continue
102 }
103 q, err := p.Question()
104 if err != nil || !checkResponse(id, query, h, q) {
105 continue
106 }
107 return p, h, nil
108 }
109 }
110
111 func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
112 if _, err := c.Write(b); err != nil {
113 return dnsmessage.Parser{}, dnsmessage.Header{}, err
114 }
115
116 b = make([]byte, 1280)
117 if _, err := io.ReadFull(c, b[:2]); err != nil {
118 return dnsmessage.Parser{}, dnsmessage.Header{}, err
119 }
120 l := int(b[0])<<8 | int(b[1])
121 if l > len(b) {
122 b = make([]byte, l)
123 }
124 n, err := io.ReadFull(c, b[:l])
125 if err != nil {
126 return dnsmessage.Parser{}, dnsmessage.Header{}, err
127 }
128 var p dnsmessage.Parser
129 h, err := p.Start(b[:n])
130 if err != nil {
131 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
132 }
133 q, err := p.Question()
134 if err != nil {
135 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
136 }
137 if !checkResponse(id, query, h, q) {
138 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
139 }
140 return p, h, nil
141 }
142
143
144 func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP bool) (dnsmessage.Parser, dnsmessage.Header, error) {
145 q.Class = dnsmessage.ClassINET
146 id, udpReq, tcpReq, err := newRequest(q)
147 if err != nil {
148 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
149 }
150 var networks []string
151 if useTCP {
152 networks = []string{"tcp"}
153 } else {
154 networks = []string{"udp", "tcp"}
155 }
156 for _, network := range networks {
157 ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
158 defer cancel()
159
160 c, err := r.dial(ctx, network, server)
161 if err != nil {
162 return dnsmessage.Parser{}, dnsmessage.Header{}, err
163 }
164 if d, ok := ctx.Deadline(); ok && !d.IsZero() {
165 c.SetDeadline(d)
166 }
167 var p dnsmessage.Parser
168 var h dnsmessage.Header
169 if _, ok := c.(PacketConn); ok {
170 p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
171 } else {
172 p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
173 }
174 c.Close()
175 if err != nil {
176 return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
177 }
178 if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
179 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
180 }
181 if h.Truncated {
182 continue
183 }
184 return p, h, nil
185 }
186 return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
187 }
188
189
190 func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
191 if h.RCode == dnsmessage.RCodeNameError {
192 return errNoSuchHost
193 }
194
195 _, err := p.AnswerHeader()
196 if err != nil && err != dnsmessage.ErrSectionDone {
197 return errCannotUnmarshalDNSMessage
198 }
199
200
201
202 if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
203 return errLameReferral
204 }
205
206 if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
207
208
209
210
211
212 if h.RCode == dnsmessage.RCodeServerFailure {
213 return errServerTemporarilyMisbehaving
214 }
215 return errServerMisbehaving
216 }
217
218 return nil
219 }
220
221 func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
222 for {
223 h, err := p.AnswerHeader()
224 if err == dnsmessage.ErrSectionDone {
225 return errNoSuchHost
226 }
227 if err != nil {
228 return errCannotUnmarshalDNSMessage
229 }
230 if h.Type == qtype {
231 return nil
232 }
233 if err := p.SkipAnswer(); err != nil {
234 return errCannotUnmarshalDNSMessage
235 }
236 }
237 }
238
239
240
241 func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
242 var lastErr error
243 serverOffset := cfg.serverOffset()
244 sLen := uint32(len(cfg.servers))
245
246 n, err := dnsmessage.NewName(name)
247 if err != nil {
248 return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
249 }
250 q := dnsmessage.Question{
251 Name: n,
252 Type: qtype,
253 Class: dnsmessage.ClassINET,
254 }
255
256 for i := 0; i < cfg.attempts; i++ {
257 for j := uint32(0); j < sLen; j++ {
258 server := cfg.servers[(serverOffset+j)%sLen]
259
260 p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP)
261 if err != nil {
262 dnsErr := &DNSError{
263 Err: err.Error(),
264 Name: name,
265 Server: server,
266 }
267 if nerr, ok := err.(Error); ok && nerr.Timeout() {
268 dnsErr.IsTimeout = true
269 }
270
271
272 if _, ok := err.(*OpError); ok {
273 dnsErr.IsTemporary = true
274 }
275 lastErr = dnsErr
276 continue
277 }
278
279 if err := checkHeader(&p, h); err != nil {
280 dnsErr := &DNSError{
281 Err: err.Error(),
282 Name: name,
283 Server: server,
284 }
285 if err == errServerTemporarilyMisbehaving {
286 dnsErr.IsTemporary = true
287 }
288 if err == errNoSuchHost {
289
290
291
292 dnsErr.IsNotFound = true
293 return p, server, dnsErr
294 }
295 lastErr = dnsErr
296 continue
297 }
298
299 err = skipToAnswer(&p, qtype)
300 if err == nil {
301 return p, server, nil
302 }
303 lastErr = &DNSError{
304 Err: err.Error(),
305 Name: name,
306 Server: server,
307 }
308 if err == errNoSuchHost {
309
310
311
312 lastErr.(*DNSError).IsNotFound = true
313 return p, server, lastErr
314 }
315 }
316 }
317 return dnsmessage.Parser{}, "", lastErr
318 }
319
320
321 type resolverConfig struct {
322 initOnce sync.Once
323
324
325
326 ch chan struct{}
327 lastChecked time.Time
328
329 mu sync.RWMutex
330 dnsConfig *dnsConfig
331 }
332
333 var resolvConf resolverConfig
334
335
336 func (conf *resolverConfig) init() {
337
338
339 conf.dnsConfig = systemConf().resolv
340 if conf.dnsConfig == nil {
341 conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
342 }
343 conf.lastChecked = time.Now()
344
345
346
347 conf.ch = make(chan struct{}, 1)
348 }
349
350
351
352
353 func (conf *resolverConfig) tryUpdate(name string) {
354 conf.initOnce.Do(conf.init)
355
356
357 if !conf.tryAcquireSema() {
358 return
359 }
360 defer conf.releaseSema()
361
362 now := time.Now()
363 if conf.lastChecked.After(now.Add(-5 * time.Second)) {
364 return
365 }
366 conf.lastChecked = now
367
368 var mtime time.Time
369 if fi, err := os.Stat(name); err == nil {
370 mtime = fi.ModTime()
371 }
372 if mtime.Equal(conf.dnsConfig.mtime) {
373 return
374 }
375
376 dnsConf := dnsReadConfig(name)
377 conf.mu.Lock()
378 conf.dnsConfig = dnsConf
379 conf.mu.Unlock()
380 }
381
382 func (conf *resolverConfig) tryAcquireSema() bool {
383 select {
384 case conf.ch <- struct{}{}:
385 return true
386 default:
387 return false
388 }
389 }
390
391 func (conf *resolverConfig) releaseSema() {
392 <-conf.ch
393 }
394
395 func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
396 if !isDomainName(name) {
397
398
399
400
401
402 return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
403 }
404 resolvConf.tryUpdate("/etc/resolv.conf")
405 resolvConf.mu.RLock()
406 conf := resolvConf.dnsConfig
407 resolvConf.mu.RUnlock()
408 var (
409 p dnsmessage.Parser
410 server string
411 err error
412 )
413 for _, fqdn := range conf.nameList(name) {
414 p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
415 if err == nil {
416 break
417 }
418 if nerr, ok := err.(Error); ok && nerr.Temporary() && r.strictErrors() {
419
420
421 break
422 }
423 }
424 if err == nil {
425 return p, server, nil
426 }
427 if err, ok := err.(*DNSError); ok {
428
429
430
431 err.Name = name
432 }
433 return dnsmessage.Parser{}, "", err
434 }
435
436
437
438
439
440 func avoidDNS(name string) bool {
441 if name == "" {
442 return true
443 }
444 if name[len(name)-1] == '.' {
445 name = name[:len(name)-1]
446 }
447 return stringsHasSuffixFold(name, ".onion")
448 }
449
450
451 func (conf *dnsConfig) nameList(name string) []string {
452 if avoidDNS(name) {
453 return nil
454 }
455
456
457 l := len(name)
458 rooted := l > 0 && name[l-1] == '.'
459 if l > 254 || l == 254 && rooted {
460 return nil
461 }
462
463
464 if rooted {
465 return []string{name}
466 }
467
468 hasNdots := count(name, '.') >= conf.ndots
469 name += "."
470 l++
471
472
473 names := make([]string, 0, 1+len(conf.search))
474
475 if hasNdots {
476 names = append(names, name)
477 }
478
479 for _, suffix := range conf.search {
480 if l+len(suffix) <= 254 {
481 names = append(names, name+suffix)
482 }
483 }
484
485 if !hasNdots {
486 names = append(names, name)
487 }
488 return names
489 }
490
491
492
493
494 type hostLookupOrder int
495
496 const (
497
498 hostLookupCgo hostLookupOrder = iota
499 hostLookupFilesDNS
500 hostLookupDNSFiles
501 hostLookupFiles
502 hostLookupDNS
503 )
504
505 var lookupOrderName = map[hostLookupOrder]string{
506 hostLookupCgo: "cgo",
507 hostLookupFilesDNS: "files,dns",
508 hostLookupDNSFiles: "dns,files",
509 hostLookupFiles: "files",
510 hostLookupDNS: "dns",
511 }
512
513 func (o hostLookupOrder) String() string {
514 if s, ok := lookupOrderName[o]; ok {
515 return s
516 }
517 return "hostLookupOrder=" + itoa.Itoa(int(o)) + "??"
518 }
519
520
521
522
523
524
525
526 func (r *Resolver) goLookupHost(ctx context.Context, name string) (addrs []string, err error) {
527 return r.goLookupHostOrder(ctx, name, hostLookupFilesDNS)
528 }
529
530 func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) {
531 if order == hostLookupFilesDNS || order == hostLookupFiles {
532
533 addrs = lookupStaticHost(name)
534 if len(addrs) > 0 || order == hostLookupFiles {
535 return
536 }
537 }
538 ips, _, err := r.goLookupIPCNAMEOrder(ctx, "ip", name, order)
539 if err != nil {
540 return
541 }
542 addrs = make([]string, 0, len(ips))
543 for _, ip := range ips {
544 addrs = append(addrs, ip.String())
545 }
546 return
547 }
548
549
550 func goLookupIPFiles(name string) (addrs []IPAddr) {
551 for _, haddr := range lookupStaticHost(name) {
552 haddr, zone := splitHostZone(haddr)
553 if ip := ParseIP(haddr); ip != nil {
554 addr := IPAddr{IP: ip, Zone: zone}
555 addrs = append(addrs, addr)
556 }
557 }
558 sortByRFC6724(addrs)
559 return
560 }
561
562
563
564 func (r *Resolver) goLookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
565 order := systemConf().hostLookupOrder(r, host)
566 addrs, _, err = r.goLookupIPCNAMEOrder(ctx, network, host, order)
567 return
568 }
569
570 func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) {
571 if order == hostLookupFilesDNS || order == hostLookupFiles {
572 addrs = goLookupIPFiles(name)
573 if len(addrs) > 0 || order == hostLookupFiles {
574 return addrs, dnsmessage.Name{}, nil
575 }
576 }
577 if !isDomainName(name) {
578
579 return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
580 }
581 resolvConf.tryUpdate("/etc/resolv.conf")
582 resolvConf.mu.RLock()
583 conf := resolvConf.dnsConfig
584 resolvConf.mu.RUnlock()
585 type result struct {
586 p dnsmessage.Parser
587 server string
588 error
589 }
590 lane := make(chan result, 1)
591 qtypes := []dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
592 switch ipVersion(network) {
593 case '4':
594 qtypes = []dnsmessage.Type{dnsmessage.TypeA}
595 case '6':
596 qtypes = []dnsmessage.Type{dnsmessage.TypeAAAA}
597 }
598 var queryFn func(fqdn string, qtype dnsmessage.Type)
599 var responseFn func(fqdn string, qtype dnsmessage.Type) result
600 if conf.singleRequest {
601 queryFn = func(fqdn string, qtype dnsmessage.Type) {}
602 responseFn = func(fqdn string, qtype dnsmessage.Type) result {
603 dnsWaitGroup.Add(1)
604 defer dnsWaitGroup.Done()
605 p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
606 return result{p, server, err}
607 }
608 } else {
609 queryFn = func(fqdn string, qtype dnsmessage.Type) {
610 dnsWaitGroup.Add(1)
611 go func(qtype dnsmessage.Type) {
612 p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
613 lane <- result{p, server, err}
614 dnsWaitGroup.Done()
615 }(qtype)
616 }
617 responseFn = func(fqdn string, qtype dnsmessage.Type) result {
618 return <-lane
619 }
620 }
621 var lastErr error
622 for _, fqdn := range conf.nameList(name) {
623 for _, qtype := range qtypes {
624 queryFn(fqdn, qtype)
625 }
626 hitStrictError := false
627 for _, qtype := range qtypes {
628 result := responseFn(fqdn, qtype)
629 if result.error != nil {
630 if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
631
632 hitStrictError = true
633 lastErr = result.error
634 } else if lastErr == nil || fqdn == name+"." {
635
636 lastErr = result.error
637 }
638 continue
639 }
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656 loop:
657 for {
658 h, err := result.p.AnswerHeader()
659 if err != nil && err != dnsmessage.ErrSectionDone {
660 lastErr = &DNSError{
661 Err: "cannot marshal DNS message",
662 Name: name,
663 Server: result.server,
664 }
665 }
666 if err != nil {
667 break
668 }
669 switch h.Type {
670 case dnsmessage.TypeA:
671 a, err := result.p.AResource()
672 if err != nil {
673 lastErr = &DNSError{
674 Err: "cannot marshal DNS message",
675 Name: name,
676 Server: result.server,
677 }
678 break loop
679 }
680 addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
681
682 case dnsmessage.TypeAAAA:
683 aaaa, err := result.p.AAAAResource()
684 if err != nil {
685 lastErr = &DNSError{
686 Err: "cannot marshal DNS message",
687 Name: name,
688 Server: result.server,
689 }
690 break loop
691 }
692 addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
693
694 default:
695 if err := result.p.SkipAnswer(); err != nil {
696 lastErr = &DNSError{
697 Err: "cannot marshal DNS message",
698 Name: name,
699 Server: result.server,
700 }
701 break loop
702 }
703 continue
704 }
705 if cname.Length == 0 && h.Name.Length != 0 {
706 cname = h.Name
707 }
708 }
709 }
710 if hitStrictError {
711
712
713
714 addrs = nil
715 break
716 }
717 if len(addrs) > 0 {
718 break
719 }
720 }
721 if lastErr, ok := lastErr.(*DNSError); ok {
722
723
724
725 lastErr.Name = name
726 }
727 sortByRFC6724(addrs)
728 if len(addrs) == 0 {
729 if order == hostLookupDNSFiles {
730 addrs = goLookupIPFiles(name)
731 }
732 if len(addrs) == 0 && lastErr != nil {
733 return nil, dnsmessage.Name{}, lastErr
734 }
735 }
736 return addrs, cname, nil
737 }
738
739
740 func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (string, error) {
741 order := systemConf().hostLookupOrder(r, host)
742 _, cname, err := r.goLookupIPCNAMEOrder(ctx, "ip", host, order)
743 return cname.String(), err
744 }
745
746
747
748
749
750
751 func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, error) {
752 names := lookupStaticAddr(addr)
753 if len(names) > 0 {
754 return names, nil
755 }
756 arpa, err := reverseaddr(addr)
757 if err != nil {
758 return nil, err
759 }
760 p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR)
761 if err != nil {
762 return nil, err
763 }
764 var ptrs []string
765 for {
766 h, err := p.AnswerHeader()
767 if err == dnsmessage.ErrSectionDone {
768 break
769 }
770 if err != nil {
771 return nil, &DNSError{
772 Err: "cannot marshal DNS message",
773 Name: addr,
774 Server: server,
775 }
776 }
777 if h.Type != dnsmessage.TypePTR {
778 err := p.SkipAnswer()
779 if err != nil {
780 return nil, &DNSError{
781 Err: "cannot marshal DNS message",
782 Name: addr,
783 Server: server,
784 }
785 }
786 continue
787 }
788 ptr, err := p.PTRResource()
789 if err != nil {
790 return nil, &DNSError{
791 Err: "cannot marshal DNS message",
792 Name: addr,
793 Server: server,
794 }
795 }
796 ptrs = append(ptrs, ptr.PTR.String())
797
798 }
799 return ptrs, nil
800 }
801
View as plain text