Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 . "net/http"
30 "net/http/httptest"
31 "net/http/httptrace"
32 "net/http/httputil"
33 "net/http/internal/testcert"
34 "net/textproto"
35 "net/url"
36 "os"
37 "reflect"
38 "runtime"
39 "strconv"
40 "strings"
41 "sync"
42 "sync/atomic"
43 "testing"
44 "testing/iotest"
45 "time"
46
47 "golang.org/x/net/http/httpguts"
48 )
49
50
51
52
53
// hostPortHandler writes the client's remote address ("host:port") back as
// the response body. It always sets X-Saw-Close to the request's Close
// field. When the query parameter close=true is present, it also sets a
// "Connection: close" response header.
var hostPortHandler = HandlerFunc(func(rw ResponseWriter, req *Request) {
	wantClose := req.FormValue("close") == "true"
	if wantClose {
		rw.Header().Set("Connection", "close")
	}
	saw := fmt.Sprint(req.Close)
	rw.Header().Set("X-Saw-Close", saw)
	rw.Write([]byte(req.RemoteAddr))
})
61
62
// testCloseConn wraps a net.Conn so that closing it is recorded in the
// owning testConnSet before the underlying connection is closed.
type testCloseConn struct {
	net.Conn
	set *testConnSet
}

// Close marks c as closed in its set, then closes the wrapped connection.
func (c *testCloseConn) Close() error {
	c.set.remove(c)
	return c.Conn.Close()
}

// testConnSet tracks the connections handed out by a makeTestDial dialer and
// whether each has been closed. Tests use it to assert that the Transport
// closed everything it dialed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex // guards closed and list
	closed map[net.Conn]bool
	list   []net.Conn // in order of insertion
}

// insert records c as a newly dialed, still-open connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. c stays in list so check can report positions.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}

// makeTestDial returns an empty testConnSet and a dial function with the
// signature of Transport.Dial. The dial function uses net.Dial and records
// every successfully dialed connection in the set.
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
	connSet := &testConnSet{
		t:      t,
		closed: make(map[net.Conn]bool),
	}
	dial := func(n, addr string) (net.Conn, error) {
		c, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		tc := &testCloseConn{c, connSet}
		connSet.insert(tc)
		return tc, nil
	}
	return connSet, dial
}

// check reports, via t.Errorf, every tracked connection that is still open.
//
// Connections may be closed asynchronously, so check polls up to
// checkAttempts times. Between polls it releases the lock and sleeps, and it
// reports only the connections still open after the final poll. The retry
// counter is kept distinct from the list index so every connection, including
// the first, gets the same grace period and is reported exactly once.
func (tcs *testConnSet) check(t *testing.T) {
	const (
		checkAttempts = 5
		pollInterval  = 50 * time.Millisecond
	)
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	for attempt := 1; ; attempt++ {
		var open []int // indexes into tcs.list of connections not yet closed
		for i, c := range tcs.list {
			if !tcs.closed[c] {
				open = append(open, i)
			}
		}
		if len(open) == 0 {
			return
		}
		if attempt < checkAttempts {
			// Drop the lock so pending Close calls (which take it in
			// remove) can make progress while we wait.
			tcs.mu.Unlock()
			time.Sleep(pollInterval)
			tcs.mu.Lock()
			continue
		}
		for _, i := range open {
			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, tcs.list[i], len(tcs.list))
		}
		return
	}
}
131
132 func TestReuseRequest(t *testing.T) {
133 defer afterTest(t)
134 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
135 w.Write([]byte("{}"))
136 }))
137 defer ts.Close()
138
139 c := ts.Client()
140 req, _ := NewRequest("GET", ts.URL, nil)
141 res, err := c.Do(req)
142 if err != nil {
143 t.Fatal(err)
144 }
145 err = res.Body.Close()
146 if err != nil {
147 t.Fatal(err)
148 }
149
150 res, err = c.Do(req)
151 if err != nil {
152 t.Fatal(err)
153 }
154 err = res.Body.Close()
155 if err != nil {
156 t.Fatal(err)
157 }
158 }
159
160
161
162 func TestTransportKeepAlives(t *testing.T) {
163 defer afterTest(t)
164 ts := httptest.NewServer(hostPortHandler)
165 defer ts.Close()
166
167 c := ts.Client()
168 for _, disableKeepAlive := range []bool{false, true} {
169 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
170 fetch := func(n int) string {
171 res, err := c.Get(ts.URL)
172 if err != nil {
173 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
174 }
175 body, err := io.ReadAll(res.Body)
176 if err != nil {
177 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
178 }
179 return string(body)
180 }
181
182 body1 := fetch(1)
183 body2 := fetch(2)
184
185 bodiesDiffer := body1 != body2
186 if bodiesDiffer != disableKeepAlive {
187 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
188 disableKeepAlive, bodiesDiffer, body1, body2)
189 }
190 }
191 }
192
193 func TestTransportConnectionCloseOnResponse(t *testing.T) {
194 defer afterTest(t)
195 ts := httptest.NewServer(hostPortHandler)
196 defer ts.Close()
197
198 connSet, testDial := makeTestDial(t)
199
200 c := ts.Client()
201 tr := c.Transport.(*Transport)
202 tr.Dial = testDial
203
204 for _, connectionClose := range []bool{false, true} {
205 fetch := func(n int) string {
206 req := new(Request)
207 var err error
208 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
209 if err != nil {
210 t.Fatalf("URL parse error: %v", err)
211 }
212 req.Method = "GET"
213 req.Proto = "HTTP/1.1"
214 req.ProtoMajor = 1
215 req.ProtoMinor = 1
216
217 res, err := c.Do(req)
218 if err != nil {
219 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
220 }
221 defer res.Body.Close()
222 body, err := io.ReadAll(res.Body)
223 if err != nil {
224 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
225 }
226 return string(body)
227 }
228
229 body1 := fetch(1)
230 body2 := fetch(2)
231 bodiesDiffer := body1 != body2
232 if bodiesDiffer != connectionClose {
233 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
234 connectionClose, bodiesDiffer, body1, body2)
235 }
236
237 tr.CloseIdleConnections()
238 }
239
240 connSet.check(t)
241 }
242
243 func TestTransportConnectionCloseOnRequest(t *testing.T) {
244 defer afterTest(t)
245 ts := httptest.NewServer(hostPortHandler)
246 defer ts.Close()
247
248 connSet, testDial := makeTestDial(t)
249
250 c := ts.Client()
251 tr := c.Transport.(*Transport)
252 tr.Dial = testDial
253 for _, connectionClose := range []bool{false, true} {
254 fetch := func(n int) string {
255 req := new(Request)
256 var err error
257 req.URL, err = url.Parse(ts.URL)
258 if err != nil {
259 t.Fatalf("URL parse error: %v", err)
260 }
261 req.Method = "GET"
262 req.Proto = "HTTP/1.1"
263 req.ProtoMajor = 1
264 req.ProtoMinor = 1
265 req.Close = connectionClose
266
267 res, err := c.Do(req)
268 if err != nil {
269 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
270 }
271 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(connectionClose); got != want {
272 t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v",
273 connectionClose, got, !connectionClose)
274 }
275 body, err := io.ReadAll(res.Body)
276 if err != nil {
277 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
278 }
279 return string(body)
280 }
281
282 body1 := fetch(1)
283 body2 := fetch(2)
284 bodiesDiffer := body1 != body2
285 if bodiesDiffer != connectionClose {
286 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
287 connectionClose, bodiesDiffer, body1, body2)
288 }
289
290 tr.CloseIdleConnections()
291 }
292
293 connSet.check(t)
294 }
295
296
297
298
299 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
300 defer afterTest(t)
301 ts := httptest.NewServer(hostPortHandler)
302 defer ts.Close()
303
304 c := ts.Client()
305 c.Transport.(*Transport).DisableKeepAlives = true
306
307 res, err := c.Get(ts.URL)
308 if err != nil {
309 t.Fatal(err)
310 }
311 res.Body.Close()
312 if res.Header.Get("X-Saw-Close") != "true" {
313 t.Errorf("handler didn't see Connection: close ")
314 }
315 }
316
317
318
319 func TestTransportRespectRequestWantsClose(t *testing.T) {
320 tests := []struct {
321 disableKeepAlives bool
322 close bool
323 }{
324 {disableKeepAlives: false, close: false},
325 {disableKeepAlives: false, close: true},
326 {disableKeepAlives: true, close: false},
327 {disableKeepAlives: true, close: true},
328 }
329
330 for _, tc := range tests {
331 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
332 func(t *testing.T) {
333 defer afterTest(t)
334 ts := httptest.NewServer(hostPortHandler)
335 defer ts.Close()
336
337 c := ts.Client()
338 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
339 req, err := NewRequest("GET", ts.URL, nil)
340 if err != nil {
341 t.Fatal(err)
342 }
343 count := 0
344 trace := &httptrace.ClientTrace{
345 WroteHeaderField: func(key string, field []string) {
346 if key != "Connection" {
347 return
348 }
349 if httpguts.HeaderValuesContainsToken(field, "close") {
350 count += 1
351 }
352 },
353 }
354 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
355 req.Close = tc.close
356 res, err := c.Do(req)
357 if err != nil {
358 t.Fatal(err)
359 }
360 defer res.Body.Close()
361 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
362 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
363 }
364 })
365 }
366
367 }
368
369 func TestTransportIdleCacheKeys(t *testing.T) {
370 defer afterTest(t)
371 ts := httptest.NewServer(hostPortHandler)
372 defer ts.Close()
373 c := ts.Client()
374 tr := c.Transport.(*Transport)
375
376 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
377 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
378 }
379
380 resp, err := c.Get(ts.URL)
381 if err != nil {
382 t.Error(err)
383 }
384 io.ReadAll(resp.Body)
385
386 keys := tr.IdleConnKeysForTesting()
387 if e, g := 1, len(keys); e != g {
388 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
389 }
390
391 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
392 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
393 }
394
395 tr.CloseIdleConnections()
396 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
397 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
398 }
399 }
400
401
402
403 func TestTransportReadToEndReusesConn(t *testing.T) {
404 defer afterTest(t)
405 const msg = "foobar"
406
407 var addrSeen map[string]int
408 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
409 addrSeen[r.RemoteAddr]++
410 if r.URL.Path == "/chunked/" {
411 w.WriteHeader(200)
412 w.(Flusher).Flush()
413 } else {
414 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
415 w.WriteHeader(200)
416 }
417 w.Write([]byte(msg))
418 }))
419 defer ts.Close()
420
421 buf := make([]byte, len(msg))
422
423 for pi, path := range []string{"/content-length/", "/chunked/"} {
424 wantLen := []int{len(msg), -1}[pi]
425 addrSeen = make(map[string]int)
426 for i := 0; i < 3; i++ {
427 res, err := Get(ts.URL + path)
428 if err != nil {
429 t.Errorf("Get %s: %v", path, err)
430 continue
431 }
432
433
434
435
436
437 defer res.Body.Close()
438
439 if res.ContentLength != int64(wantLen) {
440 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
441 }
442 n, err := res.Body.Read(buf)
443 if n != len(msg) || err != io.EOF {
444 t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg))
445 }
446 }
447 if len(addrSeen) != 1 {
448 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
449 }
450 }
451 }
452
453 func TestTransportMaxPerHostIdleConns(t *testing.T) {
454 defer afterTest(t)
455 stop := make(chan struct{})
456 defer close(stop)
457
458 resch := make(chan string)
459 gotReq := make(chan bool)
460 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
461 gotReq <- true
462 var msg string
463 select {
464 case <-stop:
465 return
466 case msg = <-resch:
467 }
468 _, err := w.Write([]byte(msg))
469 if err != nil {
470 t.Errorf("Write: %v", err)
471 return
472 }
473 }))
474 defer ts.Close()
475
476 c := ts.Client()
477 tr := c.Transport.(*Transport)
478 maxIdleConnsPerHost := 2
479 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
480
481
482
483 donech := make(chan bool)
484 doReq := func() {
485 defer func() {
486 select {
487 case <-stop:
488 return
489 case donech <- t.Failed():
490 }
491 }()
492 resp, err := c.Get(ts.URL)
493 if err != nil {
494 t.Error(err)
495 return
496 }
497 if _, err := io.ReadAll(resp.Body); err != nil {
498 t.Errorf("ReadAll: %v", err)
499 return
500 }
501 }
502 go doReq()
503 <-gotReq
504 go doReq()
505 <-gotReq
506 go doReq()
507 <-gotReq
508
509 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
510 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
511 }
512
513 resch <- "res1"
514 <-donech
515 keys := tr.IdleConnKeysForTesting()
516 if e, g := 1, len(keys); e != g {
517 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
518 }
519 addr := ts.Listener.Addr().String()
520 cacheKey := "|http|" + addr
521 if keys[0] != cacheKey {
522 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
523 }
524 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
525 t.Errorf("after first response, expected %d idle conns; got %d", e, g)
526 }
527
528 resch <- "res2"
529 <-donech
530 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
531 t.Errorf("after second response, idle conns = %d; want %d", g, w)
532 }
533
534 resch <- "res3"
535 <-donech
536 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
537 t.Errorf("after third response, idle conns = %d; want %d", g, w)
538 }
539 }
540
541 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
542 defer afterTest(t)
543 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
544 _, err := w.Write([]byte("foo"))
545 if err != nil {
546 t.Fatalf("Write: %v", err)
547 }
548 }))
549 defer ts.Close()
550 c := ts.Client()
551 tr := c.Transport.(*Transport)
552 dialStarted := make(chan struct{})
553 stallDial := make(chan struct{})
554 tr.Dial = func(network, addr string) (net.Conn, error) {
555 dialStarted <- struct{}{}
556 <-stallDial
557 return net.Dial(network, addr)
558 }
559
560 tr.DisableKeepAlives = true
561 tr.MaxConnsPerHost = 1
562
563 preDial := make(chan struct{})
564 reqComplete := make(chan struct{})
565 doReq := func(reqId string) {
566 req, _ := NewRequest("GET", ts.URL, nil)
567 trace := &httptrace.ClientTrace{
568 GetConn: func(hostPort string) {
569 preDial <- struct{}{}
570 },
571 }
572 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
573 resp, err := tr.RoundTrip(req)
574 if err != nil {
575 t.Errorf("unexpected error for request %s: %v", reqId, err)
576 }
577 _, err = io.ReadAll(resp.Body)
578 if err != nil {
579 t.Errorf("unexpected error for request %s: %v", reqId, err)
580 }
581 reqComplete <- struct{}{}
582 }
583
584 go doReq("req1")
585 <-preDial
586 <-dialStarted
587
588
589 go doReq("req2")
590 <-preDial
591 select {
592 case <-dialStarted:
593 t.Error("req2 dial started while req1 dial in progress")
594 return
595 default:
596 }
597
598
599 stallDial <- struct{}{}
600 <-reqComplete
601
602
603 <-dialStarted
604 stallDial <- struct{}{}
605 <-reqComplete
606 }
607
608 func TestTransportMaxConnsPerHost(t *testing.T) {
609 defer afterTest(t)
610 CondSkipHTTP2(t)
611
612 h := HandlerFunc(func(w ResponseWriter, r *Request) {
613 _, err := w.Write([]byte("foo"))
614 if err != nil {
615 t.Fatalf("Write: %v", err)
616 }
617 })
618
619 testMaxConns := func(scheme string, ts *httptest.Server) {
620 defer ts.Close()
621
622 c := ts.Client()
623 tr := c.Transport.(*Transport)
624 tr.MaxConnsPerHost = 1
625 if err := ExportHttp2ConfigureTransport(tr); err != nil {
626 t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
627 }
628
629 mu := sync.Mutex{}
630 var conns []net.Conn
631 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
632 tr.Dial = func(network, addr string) (net.Conn, error) {
633 atomic.AddInt32(&dialCnt, 1)
634 c, err := net.Dial(network, addr)
635 mu.Lock()
636 defer mu.Unlock()
637 conns = append(conns, c)
638 return c, err
639 }
640
641 doReq := func() {
642 trace := &httptrace.ClientTrace{
643 GotConn: func(connInfo httptrace.GotConnInfo) {
644 if !connInfo.Reused {
645 atomic.AddInt32(&gotConnCnt, 1)
646 }
647 },
648 TLSHandshakeStart: func() {
649 atomic.AddInt32(&tlsHandshakeCnt, 1)
650 },
651 }
652 req, _ := NewRequest("GET", ts.URL, nil)
653 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
654
655 resp, err := c.Do(req)
656 if err != nil {
657 t.Fatalf("request failed: %v", err)
658 }
659 defer resp.Body.Close()
660 _, err = io.ReadAll(resp.Body)
661 if err != nil {
662 t.Fatalf("read body failed: %v", err)
663 }
664 }
665
666 wg := sync.WaitGroup{}
667 for i := 0; i < 10; i++ {
668 wg.Add(1)
669 go func() {
670 defer wg.Done()
671 doReq()
672 }()
673 }
674 wg.Wait()
675
676 expected := int32(tr.MaxConnsPerHost)
677 if dialCnt != expected {
678 t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected)
679 }
680 if gotConnCnt != expected {
681 t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected)
682 }
683 if ts.TLS != nil && tlsHandshakeCnt != expected {
684 t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected)
685 }
686
687 if t.Failed() {
688 t.FailNow()
689 }
690
691 mu.Lock()
692 for _, c := range conns {
693 c.Close()
694 }
695 conns = nil
696 mu.Unlock()
697 tr.CloseIdleConnections()
698
699 doReq()
700 expected++
701 if dialCnt != expected {
702 t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt)
703 }
704 if gotConnCnt != expected {
705 t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected)
706 }
707 if ts.TLS != nil && tlsHandshakeCnt != expected {
708 t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected)
709 }
710 }
711
712 testMaxConns("http", httptest.NewServer(h))
713 testMaxConns("https", httptest.NewTLSServer(h))
714
715 ts := httptest.NewUnstartedServer(h)
716 ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
717 ts.StartTLS()
718 testMaxConns("http2", ts)
719 }
720
721 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
722 setParallel(t)
723 defer afterTest(t)
724 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
725 io.WriteString(w, r.RemoteAddr)
726 }))
727 defer ts.Close()
728
729 c := ts.Client()
730 tr := c.Transport.(*Transport)
731
732 doReq := func(name string) string {
733
734
735 res, err := c.Post(ts.URL, "", nil)
736 if err != nil {
737 t.Fatalf("%s: %v", name, err)
738 }
739 if res.StatusCode != 200 {
740 t.Fatalf("%s: %v", name, res.Status)
741 }
742 defer res.Body.Close()
743 slurp, err := io.ReadAll(res.Body)
744 if err != nil {
745 t.Fatalf("%s: %v", name, err)
746 }
747 return string(slurp)
748 }
749
750 first := doReq("first")
751 keys1 := tr.IdleConnKeysForTesting()
752
753 ts.CloseClientConnections()
754
755 var keys2 []string
756 if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool {
757 keys2 = tr.IdleConnKeysForTesting()
758 return len(keys2) == 0
759 }) {
760 t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2)
761 }
762
763 second := doReq("second")
764 if first == second {
765 t.Errorf("expected a different connection between requests. got %q both times", first)
766 }
767 }
768
769
770
771 func TestTransportServerClosingUnexpectedly(t *testing.T) {
772 setParallel(t)
773 defer afterTest(t)
774 ts := httptest.NewServer(hostPortHandler)
775 defer ts.Close()
776 c := ts.Client()
777
778 fetch := func(n, retries int) string {
779 condFatalf := func(format string, arg ...any) {
780 if retries <= 0 {
781 t.Fatalf(format, arg...)
782 }
783 t.Logf("retrying shortly after expected error: "+format, arg...)
784 time.Sleep(time.Second / time.Duration(retries))
785 }
786 for retries >= 0 {
787 retries--
788 res, err := c.Get(ts.URL)
789 if err != nil {
790 condFatalf("error in req #%d, GET: %v", n, err)
791 continue
792 }
793 body, err := io.ReadAll(res.Body)
794 if err != nil {
795 condFatalf("error in req #%d, ReadAll: %v", n, err)
796 continue
797 }
798 res.Body.Close()
799 return string(body)
800 }
801 panic("unreachable")
802 }
803
804 body1 := fetch(1, 0)
805 body2 := fetch(2, 0)
806
807
808
809
810
811
812
813
814 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
815
816 body3 := fetch(3, 5)
817
818 if body1 != body2 {
819 t.Errorf("expected body1 and body2 to be equal")
820 }
821 if body2 == body3 {
822 t.Errorf("expected body2 and body3 to be different")
823 }
824 }
825
826
827
828 func TestStressSurpriseServerCloses(t *testing.T) {
829 defer afterTest(t)
830 if testing.Short() {
831 t.Skip("skipping test in short mode")
832 }
833 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
834 w.Header().Set("Content-Length", "5")
835 w.Header().Set("Content-Type", "text/plain")
836 w.Write([]byte("Hello"))
837 w.(Flusher).Flush()
838 conn, buf, _ := w.(Hijacker).Hijack()
839 buf.Flush()
840 conn.Close()
841 }))
842 defer ts.Close()
843 c := ts.Client()
844
845
846
847
848
849
850
851 const (
852 numClients = 20
853 reqsPerClient = 25
854 )
855 activityc := make(chan bool)
856 for i := 0; i < numClients; i++ {
857 go func() {
858 for i := 0; i < reqsPerClient; i++ {
859 res, err := c.Get(ts.URL)
860 if err == nil {
861
862
863
864
865
866
867 res.Body.Close()
868 }
869 if !<-activityc {
870 return
871 }
872 }
873 }()
874 }
875
876
877 for i := 0; i < numClients*reqsPerClient; i++ {
878 select {
879 case activityc <- true:
880 case <-time.After(5 * time.Second):
881 close(activityc)
882 t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile")
883 }
884 }
885 }
886
887
888
889 func TestTransportHeadResponses(t *testing.T) {
890 defer afterTest(t)
891 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
892 if r.Method != "HEAD" {
893 panic("expected HEAD; got " + r.Method)
894 }
895 w.Header().Set("Content-Length", "123")
896 w.WriteHeader(200)
897 }))
898 defer ts.Close()
899 c := ts.Client()
900
901 for i := 0; i < 2; i++ {
902 res, err := c.Head(ts.URL)
903 if err != nil {
904 t.Errorf("error on loop %d: %v", i, err)
905 continue
906 }
907 if e, g := "123", res.Header.Get("Content-Length"); e != g {
908 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
909 }
910 if e, g := int64(123), res.ContentLength; e != g {
911 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
912 }
913 if all, err := io.ReadAll(res.Body); err != nil {
914 t.Errorf("loop %d: Body ReadAll: %v", i, err)
915 } else if len(all) != 0 {
916 t.Errorf("Bogus body %q", all)
917 }
918 }
919 }
920
921
922
923 func TestTransportHeadChunkedResponse(t *testing.T) {
924 defer afterTest(t)
925 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
926 if r.Method != "HEAD" {
927 panic("expected HEAD; got " + r.Method)
928 }
929 w.Header().Set("Transfer-Encoding", "chunked")
930 w.Header().Set("x-client-ipport", r.RemoteAddr)
931 w.WriteHeader(200)
932 }))
933 defer ts.Close()
934 c := ts.Client()
935
936
937
938 didRead := make(chan bool)
939 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
940 defer SetReadLoopBeforeNextReadHook(nil)
941
942 res1, err := c.Head(ts.URL)
943 <-didRead
944
945 if err != nil {
946 t.Fatalf("request 1 error: %v", err)
947 }
948
949 res2, err := c.Head(ts.URL)
950 <-didRead
951
952 if err != nil {
953 t.Fatalf("request 2 error: %v", err)
954 }
955 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
956 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
957 }
958 }
959
// roundTripTests drives TestRoundTripGzip. Fields:
//   - accept: the Accept-Encoding the request sets explicitly ("" means none).
//   - expectAccept: the Accept-Encoding the handler should receive.
//   - compressed: whether the caller must gunzip the returned body itself.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No explicit Accept-Encoding: the server should still receive "gzip".
	{accept: "", expectAccept: "gzip", compressed: false},
	// An explicit, non-gzip Accept-Encoding is passed through unchanged.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit gzip request gets the compressed body back as-is.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
972
973
974 func TestRoundTripGzip(t *testing.T) {
975 setParallel(t)
976 defer afterTest(t)
977 const responseBody = "test response body"
978 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
979 accept := req.Header.Get("Accept-Encoding")
980 if expect := req.FormValue("expect_accept"); accept != expect {
981 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
982 req.FormValue("testnum"), accept, expect)
983 }
984 if accept == "gzip" {
985 rw.Header().Set("Content-Encoding", "gzip")
986 gz := gzip.NewWriter(rw)
987 gz.Write([]byte(responseBody))
988 gz.Close()
989 } else {
990 rw.Header().Set("Content-Encoding", accept)
991 rw.Write([]byte(responseBody))
992 }
993 }))
994 defer ts.Close()
995 tr := ts.Client().Transport.(*Transport)
996
997 for i, test := range roundTripTests {
998
999 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1000 if test.accept != "" {
1001 req.Header.Set("Accept-Encoding", test.accept)
1002 }
1003 res, err := tr.RoundTrip(req)
1004 if err != nil {
1005 t.Errorf("%d. RoundTrip: %v", i, err)
1006 continue
1007 }
1008 var body []byte
1009 if test.compressed {
1010 var r *gzip.Reader
1011 r, err = gzip.NewReader(res.Body)
1012 if err != nil {
1013 t.Errorf("%d. gzip NewReader: %v", i, err)
1014 continue
1015 }
1016 body, err = io.ReadAll(r)
1017 res.Body.Close()
1018 } else {
1019 body, err = io.ReadAll(res.Body)
1020 }
1021 if err != nil {
1022 t.Errorf("%d. Error: %q", i, err)
1023 continue
1024 }
1025 if g, e := string(body), responseBody; g != e {
1026 t.Errorf("%d. body = %q; want %q", i, g, e)
1027 }
1028 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1029 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1030 }
1031 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1032 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1033 }
1034 }
1035
1036 }
1037
1038 func TestTransportGzip(t *testing.T) {
1039 setParallel(t)
1040 defer afterTest(t)
1041 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1042 const nRandBytes = 1024 * 1024
1043 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
1044 if req.Method == "HEAD" {
1045 if g := req.Header.Get("Accept-Encoding"); g != "" {
1046 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1047 }
1048 return
1049 }
1050 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1051 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1052 }
1053 rw.Header().Set("Content-Encoding", "gzip")
1054
1055 var w io.Writer = rw
1056 var buf bytes.Buffer
1057 if req.FormValue("chunked") == "0" {
1058 w = &buf
1059 defer io.Copy(rw, &buf)
1060 defer func() {
1061 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1062 }()
1063 }
1064 gz := gzip.NewWriter(w)
1065 gz.Write([]byte(testString))
1066 if req.FormValue("body") == "large" {
1067 io.CopyN(gz, rand.Reader, nRandBytes)
1068 }
1069 gz.Close()
1070 }))
1071 defer ts.Close()
1072 c := ts.Client()
1073
1074 for _, chunked := range []string{"1", "0"} {
1075
1076 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1077 if err != nil {
1078 t.Fatalf("large get: %v", err)
1079 }
1080 buf := make([]byte, len(testString))
1081 n, err := io.ReadFull(res.Body, buf)
1082 if err != nil {
1083 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1084 }
1085 if e, g := testString, string(buf); e != g {
1086 t.Errorf("partial read got %q, expected %q", g, e)
1087 }
1088 res.Body.Close()
1089
1090 n, err = res.Body.Read(buf)
1091 if n != 0 || err == nil {
1092 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1093 }
1094
1095
1096 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1097 if err != nil {
1098 t.Fatal(err)
1099 }
1100 body, err := io.ReadAll(res.Body)
1101 if err != nil {
1102 t.Fatal(err)
1103 }
1104 if g, e := string(body), testString; g != e {
1105 t.Fatalf("body = %q; want %q", g, e)
1106 }
1107 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1108 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1109 }
1110
1111
1112 n, err = res.Body.Read(buf)
1113 if n != 0 || err == nil {
1114 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1115 }
1116 res.Body.Close()
1117 n, err = res.Body.Read(buf)
1118 if n != 0 || err == nil {
1119 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1120 }
1121 }
1122
1123
1124 res, err := c.Head(ts.URL)
1125 if err != nil {
1126 t.Fatalf("Head: %v", err)
1127 }
1128 if res.StatusCode != 200 {
1129 t.Errorf("Head status=%d; want=200", res.StatusCode)
1130 }
1131 }
1132
1133
1134
1135 func TestTransportExpect100Continue(t *testing.T) {
1136 setParallel(t)
1137 defer afterTest(t)
1138
1139 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
1140 switch req.URL.Path {
1141 case "/100":
1142
1143 if _, err := io.Copy(io.Discard, req.Body); err != nil {
1144 t.Error("Failed to read Body", err)
1145 }
1146 rw.WriteHeader(StatusOK)
1147 case "/200":
1148
1149
1150 rw.WriteHeader(StatusOK)
1151 case "/500":
1152 rw.WriteHeader(StatusInternalServerError)
1153 case "/keepalive":
1154
1155 _, bufrw, err := rw.(Hijacker).Hijack()
1156 if err != nil {
1157 log.Fatal(err)
1158 }
1159 bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
1160 bufrw.WriteString("Content-Length: 0\r\n\r\n")
1161 bufrw.Flush()
1162 case "/timeout":
1163
1164
1165 conn, bufrw, err := rw.(Hijacker).Hijack()
1166 if err != nil {
1167 log.Fatal(err)
1168 }
1169 if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
1170 t.Error("Failed to read Body", err)
1171 }
1172 bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
1173 bufrw.Flush()
1174 conn.Close()
1175 }
1176
1177 }))
1178 defer ts.Close()
1179
1180 tests := []struct {
1181 path string
1182 body []byte
1183 sent int
1184 status int
1185 }{
1186 {path: "/100", body: []byte("hello"), sent: 5, status: 200},
1187 {path: "/200", body: []byte("hello"), sent: 0, status: 200},
1188 {path: "/500", body: []byte("hello"), sent: 0, status: 500},
1189 {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500},
1190 {path: "/timeout", body: []byte("hello"), sent: 5, status: 200},
1191 }
1192
1193 c := ts.Client()
1194 for i, v := range tests {
1195 tr := &Transport{
1196 ExpectContinueTimeout: 2 * time.Second,
1197 }
1198 defer tr.CloseIdleConnections()
1199 c.Transport = tr
1200 body := bytes.NewReader(v.body)
1201 req, err := NewRequest("PUT", ts.URL+v.path, body)
1202 if err != nil {
1203 t.Fatal(err)
1204 }
1205 req.Header.Set("Expect", "100-continue")
1206 req.ContentLength = int64(len(v.body))
1207
1208 resp, err := c.Do(req)
1209 if err != nil {
1210 t.Fatal(err)
1211 }
1212 resp.Body.Close()
1213
1214 sent := len(v.body) - body.Len()
1215 if v.status != resp.StatusCode {
1216 t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
1217 }
1218 if v.sent != sent {
1219 t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
1220 }
1221 }
1222 }
1223
// TestSOCKS5Proxy verifies that a Transport whose Proxy is a socks5:// URL
// performs a SOCKS5 no-auth CONNECT handshake with the proxy and then
// tunnels both plain-HTTP and TLS requests through it.
func TestSOCKS5Proxy(t *testing.T) {
	defer afterTest(t)
	ch := make(chan string, 1)
	l := newLocalListener(t)
	defer l.Close()
	defer close(ch)

	// proxy serves exactly one SOCKS5 client accepted from l. After the
	// handshake it reports the requested target on ch and relays bytes in
	// both directions until the target side stops.
	proxy := func(t *testing.T) {
		s, err := l.Accept()
		if err != nil {
			t.Errorf("socks5 proxy Accept(): %v", err)
			return
		}
		defer s.Close()
		// Room for the 4-byte request header, a 16-byte IPv6 address and
		// the 2-byte port.
		var buf [22]byte
		// Method negotiation: VER=5, NMETHODS=1, METHODS=[0 (no auth)].
		if _, err := io.ReadFull(s, buf[:3]); err != nil {
			t.Errorf("socks5 proxy initial read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
			return
		}
		// Select method 0: no authentication required.
		if _, err := s.Write([]byte{5, 0}); err != nil {
			t.Errorf("socks5 proxy initial write: %v", err)
			return
		}
		// Request header: VER=5, CMD=1 (CONNECT), RSV=0, ATYP.
		if _, err := io.ReadFull(s, buf[:4]); err != nil {
			t.Errorf("socks5 proxy second read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
			return
		}
		var ipLen int
		switch buf[3] {
		case 1: // IPv4 address
			ipLen = net.IPv4len
		case 4: // IPv6 address
			ipLen = net.IPv6len
		default:
			// ATYP is the fourth header byte; buf[4] has not been read yet.
			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[3])
			return
		}
		// Destination address followed by a big-endian 16-bit port.
		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy address read: %v", err)
			return
		}
		ip := net.IP(buf[4 : ipLen+4])
		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
		// Reply: VER=5, REP=0 (succeeded), RSV=0, then echo back the ATYP,
		// address and port already held in buf.
		copy(buf[:3], []byte{5, 0, 0})
		if _, err := s.Write(buf[:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy connect write: %v", err)
			return
		}
		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)

		// Relay traffic between the client and the requested target.
		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
		targetConn, err := net.Dial("tcp", targetHost)
		if err != nil {
			t.Errorf("net.Dial(%q) failed: %v", targetHost, err)
			return
		}
		go io.Copy(targetConn, s)
		io.Copy(s, targetConn) // Returns once the target side hits EOF or errors.
		targetConn.Close()
	}

	pu, err := url.Parse("socks5://" + l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	sentinelHeader := "X-Sentinel"
	sentinelValue := "12345"
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set(sentinelHeader, sentinelValue)
	})
	for _, useTLS := range []bool{false, true} {
		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
			var ts *httptest.Server
			if useTLS {
				ts = httptest.NewTLSServer(h)
			} else {
				ts = httptest.NewServer(h)
			}
			go proxy(t)
			c := ts.Client()
			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			r, err := c.Head(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			// Only the header is inspected; release the response body.
			r.Body.Close()
			if r.Header.Get(sentinelHeader) != sentinelValue {
				t.Errorf("Failed to retrieve sentinel value")
			}
			var got string
			select {
			case got = <-ch:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout connecting to socks5 proxy")
			}
			ts.Close()
			tsu, err := url.Parse(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			want := "proxy for " + tsu.Host
			if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		})
	}
}
1339
// TestTransportProxy checks that a Transport with Proxy set routes requests
// through the proxy for every combination of HTTP/HTTPS site and HTTP/HTTPS
// proxy: plain-HTTP sites are requested from the proxy with the absolute site
// URL, while HTTPS sites are reached through a CONNECT tunnel.
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ httpsSite, httpsProxy bool }{
		{false, false},
		{false, true},
		{true, false},
		{true, true},
	}
	for _, testCase := range testCases {
		httpsSite := testCase.httpsSite
		httpsProxy := testCase.httpsProxy
		t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
			// siteCh receives each request that reaches the destination site.
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			// proxyCh receives each request that reaches the proxy.
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// Minimal CONNECT proxy: hijack the client connection, dial
				// the target, reply 200 OK, then relay bytes both ways.
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed: %v", err)
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}

					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						// The hijacked conn belongs to this handler now;
						// close it rather than leak it.
						clientConn.Close()
						return
					}

					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						clientConn.Close()
						targetConn.Close()
						return
					}

					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			var ts *httptest.Server
			if httpsSite {
				ts = httptest.NewTLSServer(h1)
			} else {
				ts = httptest.NewServer(h1)
			}
			var proxy *httptest.Server
			if httpsProxy {
				proxy = httptest.NewTLSServer(h2)
			} else {
				proxy = httptest.NewServer(h2)
			}

			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}

			// The client must trust whichever server(s) speak TLS. When the
			// site is HTTPS its client is used; presumably the httptest TLS
			// servers share a test certificate, so this client also trusts an
			// HTTPS proxy — TODO confirm.
			c := proxy.Client()
			if httpsSite {
				c = ts.Client()
			}

			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			resp, err := c.Head(ts.URL)
			if err != nil {
				t.Error(err)
			} else {
				resp.Body.Close()
			}
			var got *Request
			select {
			case got = <-proxyCh:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout connecting to http proxy")
			}
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if httpsSite {
				// The proxy should first see a CONNECT naming the real site.
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				tsu, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatal("Invalid site URL")
				}
				if wantHost := tsu.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}

				// The tunneled request itself should then reach the site.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				// Plain HTTP: the proxy receives the request with an
				// absolute URL for the site.
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}
1471
1472
1473
// TestTransportProxyHTTPSConnectLeak checks that when a request's context is
// canceled while the Transport is still waiting for the proxy's reply to an
// HTTPS CONNECT, the Transport closes its connection to the proxy. The fake
// proxy observes this as EOF.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ln := newLocalListener(t)
	defer ln.Close()
	listenerDone := make(chan struct{})
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		// Read the CONNECT request sent by the client.
		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request: %v", err)
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Never answer the CONNECT. Cancel the client's request instead, and
		// expect the client to close its end of the connection.
		cancel()
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Errorf("unexpected Get success")
	}

	// Wait for the fake proxy to finish its checks before returning.
	<-listenerDone
}
1537
1538
// TestTransportDialPreservesNetOpProxyError checks that when dialing the proxy
// fails, Client.Do returns a *url.Error wrapping a *net.OpError. That OpError
// must have Op "proxyconnect", Net "tcp", and Err set to the dialer's own error.
func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
	defer afterTest(t)

	var errDial = errors.New("some dial error")

	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) {
			return url.Parse("http://proxy.fake.tld/")
		},
		Dial: func(string, string) (net.Conn, error) {
			return nil, errDial
		},
	}
	defer tr.CloseIdleConnections()

	c := &Client{Transport: tr}
	req, err := NewRequest("GET", "http://fake.tld", nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("wanted a non-nil error")
	}

	// Exact type assertions (not errors.As) on purpose: the test pins the
	// concrete wrapping structure.
	uerr, ok := err.(*url.Error)
	if !ok {
		t.Fatalf("got %T, want *url.Error", err)
	}
	oe, ok := uerr.Err.(*net.OpError)
	if !ok {
		t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
	}
	want := &net.OpError{
		Op:  "proxyconnect",
		Net: "tcp",
		Err: errDial,
	}
	if !reflect.DeepEqual(oe, want) {
		t.Errorf("Got error %#v; want %#v", oe, want)
	}
}
1579
1580
1581
1582
1583
// TestTransportProxyDialDoesNotMutateProxyConnectHeader checks that dialing
// through a proxy URL carrying user credentials leaves the Transport's
// ProxyConnectHeader unchanged. Presumably the Transport derives a
// Proxy-Authorization header from the credentials — TODO confirm. The proxy
// answers every request with 404, so the request itself is expected to fail.
func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	proxy := httptest.NewTLSServer(NotFoundHandler())
	defer proxy.Close()
	c := proxy.Client()

	tr := c.Transport.(*Transport)
	tr.Proxy = func(*Request) (*url.URL, error) {
		u, err := url.Parse(proxy.URL)
		if err != nil {
			// Previously ignored; a nil u would panic on the next line.
			return nil, err
		}
		u.User = url.UserPassword("aladdin", "opensesame")
		return u, nil
	}
	// Remember a copy of the header so later mutation can be detected.
	h := tr.ProxyConnectHeader
	if h == nil {
		h = make(Header)
	}
	tr.ProxyConnectHeader = h.Clone()

	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Errorf("unexpected Get success")
	}

	if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
		t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
	}
}
1617
1618
1619
1620
1621
// TestTransportGzipRecursive checks that a gzip-encoded response body is
// decompressed exactly once. The served payload rgz (defined elsewhere in
// this file) should come back byte-for-byte. The Content-Encoding header
// should be removed once the Transport has decoded the body.
func TestTransportGzipRecursive(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(rgz)
	}))
	defer ts.Close()

	c := ts.Client()
	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Previously never closed, leaking the connection.
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(body, rgz) {
		t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
			body, rgz)
	}
	if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
		t.Fatalf("Content-Encoding = %q; want %q", g, e)
	}
}
1647
1648
1649
// TestTransportGzipShort serves a body that claims gzip encoding but holds
// only the two-byte gzip magic number. Reading that body must fail with
// exactly io.ErrUnexpectedEOF.
func TestTransportGzipShort(t *testing.T) {
	defer afterTest(t)
	truncatedGzip := []byte{0x1f, 0x8b}
	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(truncatedGzip)
	}))
	defer srv.Close()

	client := srv.Client()
	res, err := client.Get(srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	switch _, readErr := io.ReadAll(res.Body); {
	case readErr == nil:
		t.Fatal("Expect an error from reading a body.")
	case readErr != io.ErrUnexpectedEOF:
		t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", readErr)
	}
}
1672
1673
// waitNumGoroutine polls runtime.NumGoroutine until the count is at most
// nmax, giving up after ten retries. Between polls it sleeps 50ms and forces
// a GC. It returns the last count it observed.
func waitNumGoroutine(nmax int) int {
	n := runtime.NumGoroutine()
	for retry := 0; retry < 10; retry++ {
		if n <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		n = runtime.NumGoroutine()
	}
	return n
}
1683
1684
// TestTransportPersistConnLeak issues numReq concurrent GETs that all block in
// the handler, then releases them and closes idle connections. Afterwards the
// goroutine count must come back to within a small margin of its starting
// value.
func TestTransportPersistConnLeak(t *testing.T) {
	defer afterTest(t)

	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	}))
	defer ts.Close()
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline: the server and client exist, but no request goroutines
	// have started yet.
	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait until each request has either reached the handler or failed on
	// the client side.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// The handler is now blocked on unblockCh.
		case <-failed:
			// A client-side failure; count it rather than wait forever.
		}
	}

	nhigh := runtime.NumGoroutine()

	// Release every blocked handler.
	close(unblockCh)

	// Wait for every client goroutine to finish its Get.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Allow a slack of 5 goroutines over the baseline; the exact bound
	// appears to be empirical — TODO confirm.
	if growth > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
1752
1753
1754
// TestTransportPersistConnLeakShortBody sends 20 POSTs whose ContentLength
// understates the actual body size, so every request fails while the body is
// being written. The test then checks that those failures do not leave
// extra goroutines behind.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
	}))
	defer ts.Close()
	c := ts.Client()
	tr := c.Transport.(*Transport)

	n0 := runtime.NumGoroutine()
	body := []byte("Hello")
	for i := 0; i < 20; i++ {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		// Declare two fewer bytes than the body actually holds.
		req.ContentLength = int64(len(body) - 2)
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	nhigh := runtime.NumGoroutine()
	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Allow a slack of 5 goroutines over the baseline; the exact bound
	// appears to be empirical — TODO confirm.
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
	if growth > 5 {
		t.Error("too many new goroutines")
	}
}
1790
1791
1792 type countedConn struct {
1793 net.Conn
1794 }
1795
1796
1797 type countingDialer struct {
1798 dialer net.Dialer
1799 mu sync.Mutex
1800 total, live int64
1801 }
1802
1803 func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
1804 conn, err := d.dialer.DialContext(ctx, network, address)
1805 if err != nil {
1806 return nil, err
1807 }
1808
1809 counted := new(countedConn)
1810 counted.Conn = conn
1811
1812 d.mu.Lock()
1813 defer d.mu.Unlock()
1814 d.total++
1815 d.live++
1816
1817 runtime.SetFinalizer(counted, d.decrement)
1818 return counted, nil
1819 }
1820
1821 func (d *countingDialer) decrement(*countedConn) {
1822 d.mu.Lock()
1823 defer d.mu.Unlock()
1824 d.live--
1825 }
1826
1827 func (d *countingDialer) Read() (total, live int64) {
1828 d.mu.Lock()
1829 defer d.mu.Unlock()
1830 return d.total, d.live
1831 }
1832
// TestTransportPersistConnLeakNeverIdle checks that client connections which
// never become reusable are eventually released. The server hijacks and
// closes each connection without responding. The test keeps issuing failing
// POSTs, forcing a GC after each one. It stops once countingDialer's
// finalizer-driven live count drops below the total number of dials.
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	defer afterTest(t)

	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		// Close the connection without writing a response, so every client
		// request fails and its connection cannot be reused.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	}))
	defer ts.Close()

	// Route all dials through d. d.live is decremented by a finalizer once a
	// dialed conn's wrapper is collected.
	var d countingDialer
	c := ts.Client()
	c.Transport.(*Transport).DialContext = d.DialContext

	body := []byte("Hello")
	for i := 0; ; i++ {
		// Success: at least one dialed connection has been finalized.
		total, live := d.Read()
		if live < total {
			break
		}
		// Give up after 4096 request/GC rounds.
		if i >= 1<<12 {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("expected broken connection")
		}

		// Let finalizers for unreachable connections become eligible to run.
		runtime.GC()
	}
}
1873
1874 type countedContext struct {
1875 context.Context
1876 }
1877
1878 type contextCounter struct {
1879 mu sync.Mutex
1880 live int64
1881 }
1882
1883 func (cc *contextCounter) Track(ctx context.Context) context.Context {
1884 counted := new(countedContext)
1885 counted.Context = ctx
1886 cc.mu.Lock()
1887 defer cc.mu.Unlock()
1888 cc.live++
1889 runtime.SetFinalizer(counted, cc.decrement)
1890 return counted
1891 }
1892
1893 func (cc *contextCounter) decrement(*countedContext) {
1894 cc.mu.Lock()
1895 defer cc.mu.Unlock()
1896 cc.live--
1897 }
1898
1899 func (cc *contextCounter) Read() (live int64) {
1900 cc.mu.Lock()
1901 defer cc.mu.Unlock()
1902 return cc.live
1903 }
1904
// TestTransportPersistConnContextLeakMaxConnsPerHost checks that the Transport
// does not retain the contexts of completed requests when MaxConnsPerHost
// limits it to one connection. After a batch of 64 concurrent POSTs, every
// context tracked for that batch must eventually be finalized.
func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
	defer afterTest(t)

	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		runtime.Gosched()
		w.WriteHeader(StatusOK)
	}))
	defer ts.Close()

	c := ts.Client()
	c.Transport.(*Transport).MaxConnsPerHost = 1

	ctx := context.Background()
	body := []byte("Hello")
	// doPosts issues 64 concurrent POSTs, each with its own context tracked
	// by cc, and waits for all of them to finish.
	doPosts := func(cc *contextCounter) {
		var wg sync.WaitGroup
		for n := 64; n > 0; n-- {
			wg.Add(1)
			go func() {
				defer wg.Done()

				ctx := cc.Track(ctx)
				req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
				if err != nil {
					t.Error(err)
					// req is nil here; continuing would panic.
					return
				}

				res, err := c.Do(req.WithContext(ctx))
				if err != nil {
					t.Errorf("Do failed with error: %v", err)
					return
				}
				res.Body.Close()
			}()
		}
		wg.Wait()
	}

	var initialCC contextCounter
	doPosts(&initialCC)

	// Contexts are counted as released only when the GC finalizes their
	// wrappers. Keep sending traffic under a separate counter (flushCC) and
	// forcing GC until the first batch is fully finalized. Presumably the
	// extra traffic flushes Transport state that still references the first
	// batch — TODO confirm.
	var flushCC contextCounter
	for i := 0; ; i++ {
		live := initialCC.Read()
		if live == 0 {
			break
		}
		if i >= 100 {
			t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
		}
		doPosts(&flushCC)
		runtime.GC()
	}
}
1960
1961
// TestTransportIdleConnCrash calls Transport.CloseIdleConnections from inside
// a handler while a request on that same Transport is still in flight. The
// request must still complete.
func TestTransportIdleConnCrash(t *testing.T) {
	defer afterTest(t)

	// tr is assigned after the server starts. The handler reads it only
	// after receiving from release, and release is sent after the
	// assignment below.
	var tr *Transport
	release := make(chan bool, 1)
	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		<-release
		tr.CloseIdleConnections()
	}))
	defer srv.Close()

	client := srv.Client()
	tr = client.Transport.(*Transport)

	finished := make(chan bool)
	go func() {
		defer func() { finished <- true }()
		res, err := client.Get(srv.URL)
		if err != nil {
			t.Error(err)
			return
		}
		res.Body.Close()
	}()
	release <- true
	<-finished
}
1988
1989
1990
1991
1992
// TestIssue3644 checks that a response whose handler sets "Connection: close"
// and writes its body in many small pieces is read back in full.
func TestIssue3644(t *testing.T) {
	defer afterTest(t)
	const numFoos = 5000
	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "close")
		for n := numFoos; n > 0; n-- {
			w.Write([]byte("foo "))
		}
	}))
	defer srv.Close()

	client := srv.Client()
	res, err := client.Get(srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	payload, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if want := numFoos * len("foo "); len(payload) != want {
		t.Errorf("unexpected response length")
	}
}
2017
2018
2019
// TestIssue3595 checks that when the server rejects a request while the
// client is still streaming an endless body, the client still receives the
// server's error response.
func TestIssue3595(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	const deniedMsg = "sorry, denied."
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		Error(w, deniedMsg, StatusUnauthorized)
	}))
	defer ts.Close()
	c := ts.Client()
	res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
	if err != nil {
		t.Errorf("Post: %v", err)
		return
	}
	// Previously never closed, leaking the connection.
	defer res.Body.Close()
	got, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("Body ReadAll: %v", err)
	}
	if !strings.Contains(string(got), deniedMsg) {
		t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
	}
}
2042
2043
2044
// TestChunkedNoContent sends repeated GETs to a handler that replies 204 No
// Content. It runs the requests first closing each body, then leaving the
// bodies open, and expects every request to succeed either way.
func TestChunkedNoContent(t *testing.T) {
	defer afterTest(t)
	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(StatusNoContent)
	}))
	defer srv.Close()

	client := srv.Client()
	const n = 4
	for _, closeBody := range []bool{true, false} {
		for i := 1; i <= n; i++ {
			res, err := client.Get(srv.URL)
			switch {
			case err != nil:
				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
			case closeBody:
				res.Body.Close()
			}
		}
	}
}
2067
// TestTransportConcurrency sends numReqs echo requests from 2*maxProcs worker
// goroutines that share one Client. Every response body must equal the value
// that its request asked the server to echo.
func TestTransportConcurrency(t *testing.T) {
	defer afterTest(t)
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	}))
	defer ts.Close()

	var wg sync.WaitGroup
	wg.Add(numReqs)

	// The test-only dial hooks add to wg and release it around pending dials,
	// presumably when a dial starts and when it finishes — TODO confirm. That
	// way wg.Wait below also waits for dials still in flight, not just for
	// the requests.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	defer close(reqs)

	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					t.Errorf("error on req %s: %v", req, err)
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				// Close on every path, including read errors, so the
				// connection is never leaked.
				res.Body.Close()
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
					wg.Done()
					continue
				}
				if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2125
// TestIssue4191_InfiniteGetTimeout checks that a deadline set on the client's
// net.Conn by a custom Dial interrupts reading an endless response body. If
// the GET itself fails, the timeout is raised tenfold once and that run is
// retried.
func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	ts := httptest.NewServer(mux)
	defer ts.Close()
	timeout := 100 * time.Millisecond

	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		// Reads the current value of timeout, which may have been raised.
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// The machine may be slow: lengthen the timeout once and
				// retry this run.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		_, err = io.Copy(io.Discard, sres.Body)
		// Previously never closed; release the connection either way.
		sres.Body.Close()
		if err == nil {
			t.Errorf("Unexpected successful copy")
			break
		}
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
}
2183
// TestIssue4191_InfiniteGetToPutTimeout streams an endless GET response body
// straight into a PUT request to the same server. The client connections
// carry a deadline from a custom Dial, so the PUT must fail. If the GET itself
// fails, the timeout is raised tenfold once and that run is retried.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := httptest.NewServer(mux)
	// Deferred (matching TestIssue4191_InfiniteGetTimeout) so the server is
	// closed on every exit path.
	defer ts.Close()
	timeout := 100 * time.Millisecond

	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// The machine may be slow: lengthen the timeout once and
				// retry this run.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		req, err := NewRequest("PUT", ts.URL+"/put", sres.Body)
		if err != nil {
			sres.Body.Close()
			t.Fatal(err)
		}
		res, err := c.Do(req)
		sres.Body.Close()
		if err == nil {
			res.Body.Close()
			t.Errorf("Unexpected successful PUT")
			break
		}
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
}
2248
// TestTransportResponseHeaderTimeout checks Transport.ResponseHeaderTimeout.
// A handler slower than the timeout must produce a url.Error that wraps a
// net.Error timeout. Fast handlers before and after the slow one must still
// succeed.
func TestTransportResponseHeaderTimeout(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping timeout test in -short mode")
	}
	inHandler := make(chan bool, 1)
	mux := NewServeMux()
	mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
		inHandler <- true
	})
	mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
		inHandler <- true
		time.Sleep(2 * time.Second)
	})
	ts := httptest.NewServer(mux)
	defer ts.Close()

	c := ts.Client()
	c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond

	tests := []struct {
		path    string
		want    int
		wantErr string
	}{
		{path: "/fast", want: 200},
		{path: "/slow", wantErr: "timeout awaiting response headers"},
		{path: "/fast", want: 200},
	}
	for i, tt := range tests {
		req, err := NewRequest("GET", ts.URL+tt.path, nil)
		if err != nil {
			t.Fatal(err)
		}
		req = req.WithT(t)
		res, err := c.Do(req)
		if err == nil {
			// Only the status code is inspected below; release the body
			// (previously never closed) on every path.
			res.Body.Close()
		}
		select {
		case <-inHandler:
		case <-time.After(5 * time.Second):
			t.Errorf("never entered handler for test index %d, %s", i, tt.path)
			continue
		}
		if err != nil {
			uerr, ok := err.(*url.Error)
			if !ok {
				t.Errorf("error is not an url.Error; got: %#v", err)
				continue
			}
			nerr, ok := uerr.Err.(net.Error)
			if !ok {
				t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
				continue
			}
			if !nerr.Timeout() {
				t.Errorf("want timeout error; got: %q", nerr)
				continue
			}
			if strings.Contains(err.Error(), tt.wantErr) {
				continue
			}
			t.Errorf("%d. unexpected error: %v", i, err)
			continue
		}
		if tt.wantErr != "" {
			t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
		}
	}
}
2319
// TestTransportCancelRequest checks that Transport.CancelRequest, called while
// the response body is being read, aborts the read. The data already
// received must be kept, the read error must be ExportErrRequestCanceled, and
// the Transport must report no pending requests afterwards.
func TestTransportCancelRequest(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	unblockc := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Hello")
		w.(Flusher).Flush()
		<-unblockc
	}))
	defer ts.Close()
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Previously never closed; closing a canceled body is harmless.
	defer res.Body.Close()
	go func() {
		time.Sleep(1 * time.Second)
		tr.CancelRequest(req)
	}()
	t0 := time.Now()
	body, err := io.ReadAll(res.Body)
	d := time.Since(t0)

	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	}
	if string(body) != "Hello" {
		t.Errorf("Body = %q; want Hello", body)
	}
	if d < 500*time.Millisecond {
		t.Errorf("expected ~1 second delay; got %v", d)
	}
	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	for tries := 5; tries > 0; tries-- {
		n := tr.NumPendingRequestsForTesting()
		if n == 0 {
			break
		}
		time.Sleep(100 * time.Millisecond)
		if tries == 1 {
			t.Errorf("pending requests = %d; want 0", n)
		}
	}
}
2373
// testTransportCancelRequestInDo starts Client.Do against a handler that
// never responds and calls Transport.CancelRequest every 100ms. Do must
// return within 10 seconds. body is used as the request body and may be nil.
func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	unblockc := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	}))
	defer ts.Close()
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	donec := make(chan bool)
	req, err := NewRequest("GET", ts.URL, body)
	if err != nil {
		t.Fatal(err)
	}
	go func() {
		defer close(donec)
		// Cancellation is expected. If Do nonetheless succeeds, release the
		// response body instead of leaking it.
		if res, err := c.Do(req); err == nil {
			res.Body.Close()
		}
	}()
	start := time.Now()
	timeout := 10 * time.Second
	for time.Since(start) < timeout {
		time.Sleep(100 * time.Millisecond)
		tr.CancelRequest(req)
		select {
		case <-donec:
			return
		default:
		}
	}
	t.Errorf("Do of canceled request has not returned after %v", timeout)
}
2409
2410 func TestTransportCancelRequestInDo(t *testing.T) {
2411 testTransportCancelRequestInDo(t, nil)
2412 }
2413
// TestTransportCancelRequestWithBodyInDo runs the cancel-during-Do scenario
// with a one-byte request body.
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
	oneByte := []byte{0}
	testTransportCancelRequestInDo(t, bytes.NewBuffer(oneByte))
}
2417
// TestTransportCancelRequestInDial checks that CancelRequest aborts a request
// whose Dial is still blocked. Client.Do must return a "request canceled
// while waiting for connection" error. The order of events is recorded in a
// log and compared byte-for-byte at the end.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records the event sequence that is checked against want below.
	var logbuf bytes.Buffer
	eventLog := log.New(&logbuf, "", 0)

	// Closed when the test returns, releasing any Dial still blocked on it.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// The main goroutine sends true once it sees Dial running. If it times
	// out instead, it closes inDial, so Dial receives false and returns
	// immediately.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	// Wait until Dial is running (and blocked).
	select {
	case inDial <- true:
	case <-time.After(5 * time.Second):
		close(inDial)
		t.Fatal("timeout; never saw blocking dial")
	}

	// Cancel twice; presumably this checks that a repeated CancelRequest is
	// harmless — TODO confirm intent.
	eventLog.Printf("canceling")
	tr.CancelRequest(req)
	tr.CancelRequest(req)

	// Do should return promptly even though Dial is still blocked. On a hang
	// this panics (not t.Fatal) and includes the event log in the message.
	select {
	case <-gotres:
	case <-time.After(5 * time.Second):
		panic("hang. events are: " + logbuf.String())
	}

	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2475
// TestCancelRequestWithChannel checks that closing Request.Cancel while the
// response body is being read aborts the read. The data already received
// must be kept, the read error must be ExportErrRequestCanceled, and the
// Transport must report no pending requests afterwards.
func TestCancelRequestWithChannel(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	unblockc := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Hello")
		w.(Flusher).Flush()
		<-unblockc
	}))
	defer ts.Close()
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ch := make(chan struct{})
	req.Cancel = ch

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Previously never closed; closing a canceled body is harmless.
	defer res.Body.Close()
	go func() {
		time.Sleep(1 * time.Second)
		close(ch)
	}()
	t0 := time.Now()
	body, err := io.ReadAll(res.Body)
	d := time.Since(t0)

	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	}
	if string(body) != "Hello" {
		t.Errorf("Body = %q; want Hello", body)
	}
	if d < 500*time.Millisecond {
		t.Errorf("expected ~1 second delay; got %v", d)
	}
	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	for tries := 5; tries > 0; tries-- {
		n := tr.NumPendingRequestsForTesting()
		if n == 0 {
			break
		}
		time.Sleep(100 * time.Millisecond)
		if tries == 1 {
			t.Errorf("pending requests = %d; want 0", n)
		}
	}
}
2532
2533 func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
2534 testCancelRequestWithChannelBeforeDo(t, false)
2535 }
2536 func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
2537 testCancelRequestWithChannelBeforeDo(t, true)
2538 }
// testCancelRequestWithChannelBeforeDo sends a request that was canceled
// before Do was called. The cancellation comes from an already-canceled
// context when withCtx is true, otherwise from a closed Request.Cancel
// channel. With a context, Do must fail with exactly context.Canceled. With
// the channel, the error must mention "canceled".
func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
	setParallel(t)
	defer afterTest(t)
	unblockc := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	}))
	defer ts.Close()
	defer close(unblockc)

	c := ts.Client()

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	if withCtx {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		req = req.WithContext(ctx)
	} else {
		ch := make(chan struct{})
		req.Cancel = ch
		close(ch)
	}

	res, err := c.Do(req)
	if err == nil {
		// Unexpected success; release the body before reporting it below.
		res.Body.Close()
	}
	if ue, ok := err.(*url.Error); ok {
		err = ue.Err
	}
	if withCtx {
		if err != context.Canceled {
			t.Errorf("Do error = %v; want %v", err, context.Canceled)
		}
	} else {
		if err == nil || !strings.Contains(err.Error(), "canceled") {
			t.Errorf("Do error = %v; want cancellation", err)
		}
	}
}
2576
2577
// TestTransportCancelBeforeResponseHeaders checks that CancelRequest aborts a
// RoundTrip that has started sending its request but has received no
// response headers. RoundTrip must then return exactly
// ExportErrRequestCanceled.
func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
	defer afterTest(t)

	// Each dial returns one end of a net.Pipe. The other end goes to the
	// test, which plays the server, so example.com is never contacted.
	serverConnCh := make(chan net.Conn, 1)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			cc, sc := net.Pipe()
			serverConnCh <- sc
			return cc, nil
		},
	}
	defer tr.CloseIdleConnections()
	errc := make(chan error, 1)
	req, _ := NewRequest("GET", "http://example.com/", nil)
	go func() {
		_, err := tr.RoundTrip(req)
		errc <- err
	}()

	// Read just the method so the request is known to be on the wire.
	// No response is ever written.
	sc := <-serverConnCh
	verb := make([]byte, 3)
	if _, err := io.ReadFull(sc, verb); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if string(verb) != "GET" {
		t.Errorf("server received %q; want GET", verb)
	}
	defer sc.Close()

	// Cancel only now, while RoundTrip is waiting for response headers.
	tr.CancelRequest(req)

	err := <-errc
	if err == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	if err != ExportErrRequestCanceled {
		t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
	}
}
2617
2618
2619
2620
2621 func TestTransportCloseResponseBody(t *testing.T) {
2622 defer afterTest(t)
2623 writeErr := make(chan error, 1)
2624 msg := []byte("young\n")
2625 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
2626 for {
2627 _, err := w.Write(msg)
2628 if err != nil {
2629 writeErr <- err
2630 return
2631 }
2632 w.(Flusher).Flush()
2633 }
2634 }))
2635 defer ts.Close()
2636
2637 c := ts.Client()
2638 tr := c.Transport.(*Transport)
2639
2640 req, _ := NewRequest("GET", ts.URL, nil)
2641 defer tr.CancelRequest(req)
2642
2643 res, err := c.Do(req)
2644 if err != nil {
2645 t.Fatal(err)
2646 }
2647
2648 const repeats = 3
2649 buf := make([]byte, len(msg)*repeats)
2650 want := bytes.Repeat(msg, repeats)
2651
2652 _, err = io.ReadFull(res.Body, buf)
2653 if err != nil {
2654 t.Fatal(err)
2655 }
2656 if !bytes.Equal(buf, want) {
2657 t.Fatalf("read %q; want %q", buf, want)
2658 }
2659 didClose := make(chan error, 1)
2660 go func() {
2661 didClose <- res.Body.Close()
2662 }()
2663 select {
2664 case err := <-didClose:
2665 if err != nil {
2666 t.Errorf("Close = %v", err)
2667 }
2668 case <-time.After(10 * time.Second):
2669 t.Fatal("too long waiting for close")
2670 }
2671 select {
2672 case err := <-writeErr:
2673 if err == nil {
2674 t.Errorf("expected non-nil write error")
2675 }
2676 case <-time.After(10 * time.Second):
2677 t.Fatal("too long waiting for write error")
2678 }
2679 }
2680
// fooProto is a RoundTripper that never touches the network. Every request
// gets a synthetic 200 response whose body echoes the request URL.
type fooProto struct{}

// RoundTrip answers req with a "200 OK" response whose body is
// "You wanted " followed by req.URL.
func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		StatusCode: 200,
		Status:     "200 OK",
		Header:     Header{},
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
2692
2693 func TestTransportAltProto(t *testing.T) {
2694 defer afterTest(t)
2695 tr := &Transport{}
2696 c := &Client{Transport: tr}
2697 tr.RegisterProtocol("foo", fooProto{})
2698 res, err := c.Get("foo://bar.com/path")
2699 if err != nil {
2700 t.Fatal(err)
2701 }
2702 bodyb, err := io.ReadAll(res.Body)
2703 if err != nil {
2704 t.Fatal(err)
2705 }
2706 body := string(bodyb)
2707 if e := "You wanted foo://bar.com/path"; body != e {
2708 t.Errorf("got response %q, want %q", body, e)
2709 }
2710 }
2711
2712 func TestTransportNoHost(t *testing.T) {
2713 defer afterTest(t)
2714 tr := &Transport{}
2715 _, err := tr.RoundTrip(&Request{
2716 Header: make(Header),
2717 URL: &url.URL{
2718 Scheme: "http",
2719 },
2720 })
2721 want := "http: no Host in request URL"
2722 if got := fmt.Sprint(err); got != want {
2723 t.Errorf("error = %v; want %q", err, want)
2724 }
2725 }
2726
2727
// TestTransportEmptyMethod checks that a client request with an empty
// Method is rendered by DumpRequestOut as a GET request.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = ""
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if out := string(dump); !strings.Contains(out, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
2739
2740 func TestTransportSocketLateBinding(t *testing.T) {
2741 setParallel(t)
2742 defer afterTest(t)
2743
2744 mux := NewServeMux()
2745 fooGate := make(chan bool, 1)
2746 mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
2747 w.Header().Set("foo-ipport", r.RemoteAddr)
2748 w.(Flusher).Flush()
2749 <-fooGate
2750 })
2751 mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
2752 w.Header().Set("bar-ipport", r.RemoteAddr)
2753 })
2754 ts := httptest.NewServer(mux)
2755 defer ts.Close()
2756
2757 dialGate := make(chan bool, 1)
2758 c := ts.Client()
2759 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2760 if <-dialGate {
2761 return net.Dial(n, addr)
2762 }
2763 return nil, errors.New("manually closed")
2764 }
2765
2766 dialGate <- true
2767 fooRes, err := c.Get(ts.URL + "/foo")
2768 if err != nil {
2769 t.Fatal(err)
2770 }
2771 fooAddr := fooRes.Header.Get("foo-ipport")
2772 if fooAddr == "" {
2773 t.Fatal("No addr on /foo request")
2774 }
2775 time.AfterFunc(200*time.Millisecond, func() {
2776
2777
2778 fooGate <- true
2779 io.Copy(io.Discard, fooRes.Body)
2780 fooRes.Body.Close()
2781 })
2782
2783 barRes, err := c.Get(ts.URL + "/bar")
2784 if err != nil {
2785 t.Fatal(err)
2786 }
2787 barAddr := barRes.Header.Get("bar-ipport")
2788 if barAddr != fooAddr {
2789 t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
2790 }
2791 barRes.Body.Close()
2792 dialGate <- false
2793 }
2794
2795
2796 func TestTransportReading100Continue(t *testing.T) {
2797 defer afterTest(t)
2798
2799 const numReqs = 5
2800 reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
2801 reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
2802
2803 send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
2804 defer w.Close()
2805 defer r.Close()
2806 br := bufio.NewReader(r)
2807 n := 0
2808 for {
2809 n++
2810 req, err := ReadRequest(br)
2811 if err == io.EOF {
2812 return
2813 }
2814 if err != nil {
2815 t.Error(err)
2816 return
2817 }
2818 slurp, err := io.ReadAll(req.Body)
2819 if err != nil {
2820 t.Errorf("Server request body slurp: %v", err)
2821 return
2822 }
2823 id := req.Header.Get("Request-Id")
2824 resCode := req.Header.Get("X-Want-Response-Code")
2825 if resCode == "" {
2826 resCode = "100 Continue"
2827 if string(slurp) != reqBody(n) {
2828 t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
2829 }
2830 }
2831 body := fmt.Sprintf("Response number %d", n)
2832 v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
2833 Date: Thu, 28 Feb 2013 17:55:41 GMT
2834
2835 HTTP/1.1 200 OK
2836 Content-Type: text/html
2837 Echo-Request-Id: %s
2838 Content-Length: %d
2839
2840 %s`, resCode, id, len(body), body), "\n", "\r\n", -1))
2841 w.Write(v)
2842 if id == reqID(numReqs) {
2843 return
2844 }
2845 }
2846
2847 }
2848
2849 tr := &Transport{
2850 Dial: func(n, addr string) (net.Conn, error) {
2851 sr, sw := io.Pipe()
2852 cr, cw := io.Pipe()
2853 conn := &rwTestConn{
2854 Reader: cr,
2855 Writer: sw,
2856 closeFunc: func() error {
2857 sw.Close()
2858 cw.Close()
2859 return nil
2860 },
2861 }
2862 go send100Response(cw, sr)
2863 return conn, nil
2864 },
2865 DisableKeepAlives: false,
2866 }
2867 defer tr.CloseIdleConnections()
2868 c := &Client{Transport: tr}
2869
2870 testResponse := func(req *Request, name string, wantCode int) {
2871 t.Helper()
2872 res, err := c.Do(req)
2873 if err != nil {
2874 t.Fatalf("%s: Do: %v", name, err)
2875 }
2876 if res.StatusCode != wantCode {
2877 t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
2878 }
2879 if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
2880 t.Errorf("%s: response id %q != request id %q", name, idBack, id)
2881 }
2882 _, err = io.ReadAll(res.Body)
2883 if err != nil {
2884 t.Fatalf("%s: Slurp error: %v", name, err)
2885 }
2886 }
2887
2888
2889 for i := 1; i <= numReqs; i++ {
2890 req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
2891 req.Header.Set("Request-Id", reqID(i))
2892 testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
2893 }
2894 }
2895
2896
2897
2898 func TestTransportIgnore1xxResponses(t *testing.T) {
2899 setParallel(t)
2900 defer afterTest(t)
2901 cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2902 conn, buf, _ := w.(Hijacker).Hijack()
2903 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
2904 buf.Flush()
2905 conn.Close()
2906 }))
2907 defer cst.close()
2908 cst.tr.DisableKeepAlives = true
2909
2910 var got bytes.Buffer
2911
2912 req, _ := NewRequest("GET", cst.ts.URL, nil)
2913 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
2914 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
2915 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
2916 return nil
2917 },
2918 }))
2919 res, err := cst.c.Do(req)
2920 if err != nil {
2921 t.Fatal(err)
2922 }
2923 defer res.Body.Close()
2924
2925 res.Write(&got)
2926 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
2927 if got.String() != want {
2928 t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want)
2929 }
2930 }
2931
2932 func TestTransportLimits1xxResponses(t *testing.T) {
2933 setParallel(t)
2934 defer afterTest(t)
2935 cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2936 conn, buf, _ := w.(Hijacker).Hijack()
2937 for i := 0; i < 10; i++ {
2938 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
2939 }
2940 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
2941 buf.Flush()
2942 conn.Close()
2943 }))
2944 defer cst.close()
2945 cst.tr.DisableKeepAlives = true
2946
2947 res, err := cst.c.Get(cst.ts.URL)
2948 if res != nil {
2949 defer res.Body.Close()
2950 }
2951 got := fmt.Sprint(err)
2952 wantSub := "too many 1xx informational responses"
2953 if !strings.Contains(got, wantSub) {
2954 t.Errorf("Get error = %v; want substring %q", err, wantSub)
2955 }
2956 }
2957
2958
2959
2960 func TestTransportTreat101Terminal(t *testing.T) {
2961 setParallel(t)
2962 defer afterTest(t)
2963 cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2964 conn, buf, _ := w.(Hijacker).Hijack()
2965 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
2966 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
2967 buf.Flush()
2968 conn.Close()
2969 }))
2970 defer cst.close()
2971 res, err := cst.c.Get(cst.ts.URL)
2972 if err != nil {
2973 t.Fatal(err)
2974 }
2975 defer res.Body.Close()
2976 if res.StatusCode != StatusSwitchingProtocols {
2977 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
2978 }
2979 }
2980
// proxyFromEnvTest describes one proxy-from-environment case: the request
// URL, the proxy-related environment values, and the expected outcome.
type proxyFromEnvTest struct {
	req string // request URL; empty means "http://example.com"

	env      string // HTTP proxy variable value
	httpsenv string // HTTPS proxy variable value
	noenv    string // no-proxy variable value
	reqmeth  string // REQUEST_METHOD value

	want    string // expected proxy URL as formatted with %s ("<nil>" for none)
	wanterr error  // expected error, compared by its %v text
}

// String renders the case's non-empty inputs as space-separated key="value"
// pairs. The output always ends with the request URL, defaulted when empty.
func (tt proxyFromEnvTest) String() string {
	var parts []string
	add := func(format, val string) {
		if val != "" {
			parts = append(parts, fmt.Sprintf(format, val))
		}
	}
	add("http_proxy=%q", tt.env)
	add("https_proxy=%q", tt.httpsenv)
	add("no_proxy=%q", tt.noenv)
	add("request_method=%q", tt.reqmeth)

	req := tt.req
	if req == "" {
		req = "http://example.com"
	}
	add("req=%q", req)
	return strings.Join(parts, " ")
}
3023
3024 var proxyFromEnvTests = []proxyFromEnvTest{
3025 {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
3026 {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
3027 {env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
3028 {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
3029 {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
3030 {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
3031 {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
3032
3033
3034 {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
3035
3036 {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
3037 {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},
3038
3039
3040
3041 {env: "http://10.1.2.3:8080", reqmeth: "POST",
3042 want: "<nil>",
3043 wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},
3044
3045 {want: "<nil>"},
3046
3047 {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
3048 {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
3049 {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
3050 {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
3051 {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
3052 }
3053
3054 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3055 t.Helper()
3056 reqURL := tt.req
3057 if reqURL == "" {
3058 reqURL = "http://example.com"
3059 }
3060 req, _ := NewRequest("GET", reqURL, nil)
3061 url, err := proxyForRequest(req)
3062 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3063 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3064 return
3065 }
3066 if got := fmt.Sprintf("%s", url); got != tt.want {
3067 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3068 }
3069 }
3070
3071 func TestProxyFromEnvironment(t *testing.T) {
3072 ResetProxyEnv()
3073 defer ResetProxyEnv()
3074 for _, tt := range proxyFromEnvTests {
3075 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3076 os.Setenv("HTTP_PROXY", tt.env)
3077 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3078 os.Setenv("NO_PROXY", tt.noenv)
3079 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3080 ResetCachedEnvironment()
3081 return ProxyFromEnvironment(req)
3082 })
3083 }
3084 }
3085
3086 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3087 ResetProxyEnv()
3088 defer ResetProxyEnv()
3089 for _, tt := range proxyFromEnvTests {
3090 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3091 os.Setenv("http_proxy", tt.env)
3092 os.Setenv("https_proxy", tt.httpsenv)
3093 os.Setenv("no_proxy", tt.noenv)
3094 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3095 ResetCachedEnvironment()
3096 return ProxyFromEnvironment(req)
3097 })
3098 }
3099 }
3100
3101 func TestIdleConnChannelLeak(t *testing.T) {
3102
3103 var mu sync.Mutex
3104 var n int
3105
3106 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3107 mu.Lock()
3108 n++
3109 mu.Unlock()
3110 }))
3111 defer ts.Close()
3112
3113 const nReqs = 5
3114 didRead := make(chan bool, nReqs)
3115 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
3116 defer SetReadLoopBeforeNextReadHook(nil)
3117
3118 c := ts.Client()
3119 tr := c.Transport.(*Transport)
3120 tr.Dial = func(netw, addr string) (net.Conn, error) {
3121 return net.Dial(netw, ts.Listener.Addr().String())
3122 }
3123
3124
3125 for _, disableKeep := range []bool{true, false} {
3126 tr.DisableKeepAlives = disableKeep
3127 for i := 0; i < nReqs; i++ {
3128 _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
3129 if err != nil {
3130 t.Fatal(err)
3131 }
3132
3133
3134
3135
3136
3137 }
3138
3139
3140
3141
3142
3143
3144
3145 for i := 0; i < nReqs; i++ {
3146 <-didRead
3147 }
3148
3149 if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
3150 t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
3151 }
3152 }
3153 }
3154
3155
3156
3157
3158 func TestTransportClosesRequestBody(t *testing.T) {
3159 defer afterTest(t)
3160 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3161 io.Copy(io.Discard, r.Body)
3162 }))
3163 defer ts.Close()
3164
3165 c := ts.Client()
3166
3167 closes := 0
3168
3169 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3170 if err != nil {
3171 t.Fatal(err)
3172 }
3173 res.Body.Close()
3174 if closes != 1 {
3175 t.Errorf("closes = %d; want 1", closes)
3176 }
3177 }
3178
3179 func TestTransportTLSHandshakeTimeout(t *testing.T) {
3180 defer afterTest(t)
3181 if testing.Short() {
3182 t.Skip("skipping in short mode")
3183 }
3184 ln := newLocalListener(t)
3185 defer ln.Close()
3186 testdonec := make(chan struct{})
3187 defer close(testdonec)
3188
3189 go func() {
3190 c, err := ln.Accept()
3191 if err != nil {
3192 t.Error(err)
3193 return
3194 }
3195 <-testdonec
3196 c.Close()
3197 }()
3198
3199 getdonec := make(chan struct{})
3200 go func() {
3201 defer close(getdonec)
3202 tr := &Transport{
3203 Dial: func(_, _ string) (net.Conn, error) {
3204 return net.Dial("tcp", ln.Addr().String())
3205 },
3206 TLSHandshakeTimeout: 250 * time.Millisecond,
3207 }
3208 cl := &Client{Transport: tr}
3209 _, err := cl.Get("https://dummy.tld/")
3210 if err == nil {
3211 t.Error("expected error")
3212 return
3213 }
3214 ue, ok := err.(*url.Error)
3215 if !ok {
3216 t.Errorf("expected url.Error; got %#v", err)
3217 return
3218 }
3219 ne, ok := ue.Err.(net.Error)
3220 if !ok {
3221 t.Errorf("expected net.Error; got %#v", err)
3222 return
3223 }
3224 if !ne.Timeout() {
3225 t.Errorf("expected timeout error; got %v", err)
3226 }
3227 if !strings.Contains(err.Error(), "handshake timeout") {
3228 t.Errorf("expected 'handshake timeout' in error; got %v", err)
3229 }
3230 }()
3231 select {
3232 case <-getdonec:
3233 case <-time.After(5 * time.Second):
3234 t.Error("test timeout; TLS handshake hung?")
3235 }
3236 }
3237
3238
3239 func TestTLSServerClosesConnection(t *testing.T) {
3240 defer afterTest(t)
3241
3242 closedc := make(chan bool, 1)
3243 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3244 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3245 conn, _, _ := w.(Hijacker).Hijack()
3246 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3247 conn.Close()
3248 closedc <- true
3249 return
3250 }
3251 fmt.Fprintf(w, "hello")
3252 }))
3253 defer ts.Close()
3254
3255 c := ts.Client()
3256 tr := c.Transport.(*Transport)
3257
3258 var nSuccess = 0
3259 var errs []error
3260 const trials = 20
3261 for i := 0; i < trials; i++ {
3262 tr.CloseIdleConnections()
3263 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3264 if err != nil {
3265 t.Fatal(err)
3266 }
3267 <-closedc
3268 slurp, err := io.ReadAll(res.Body)
3269 if err != nil {
3270 t.Fatal(err)
3271 }
3272 if string(slurp) != "foo" {
3273 t.Errorf("Got %q, want foo", slurp)
3274 }
3275
3276
3277
3278 res, err = c.Get(ts.URL + "/")
3279 if err != nil {
3280 errs = append(errs, err)
3281 continue
3282 }
3283 slurp, err = io.ReadAll(res.Body)
3284 if err != nil {
3285 errs = append(errs, err)
3286 continue
3287 }
3288 nSuccess++
3289 }
3290 if nSuccess > 0 {
3291 t.Logf("successes = %d of %d", nSuccess, trials)
3292 } else {
3293 t.Errorf("All runs failed:")
3294 }
3295 for _, err := range errs {
3296 t.Logf(" err: %v", err)
3297 }
3298 }
3299
3300
3301
3302
// byteFromChanReader is an io.Reader that returns at most one byte per Read,
// taken from the channel. Once the channel is closed and drained, Read
// returns io.EOF.
type byteFromChanReader chan byte

// Read stores the next byte received from c into p[0]. A zero-length p
// returns (0, nil) immediately, without receiving from the channel.
func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		b, ok := <-c
		if ok {
			p[0] = b
			return 1, nil
		}
		return 0, io.EOF
	}
	return 0, nil
}
3316
3317
3318
3319
3320
3321
3322
3323 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
3324 setParallel(t)
3325 defer afterTest(t)
3326 var sconn struct {
3327 sync.Mutex
3328 c net.Conn
3329 }
3330 var getOkay bool
3331 closeConn := func() {
3332 sconn.Lock()
3333 defer sconn.Unlock()
3334 if sconn.c != nil {
3335 sconn.c.Close()
3336 sconn.c = nil
3337 if !getOkay {
3338 t.Logf("Closed server connection")
3339 }
3340 }
3341 }
3342 defer closeConn()
3343
3344 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3345 if r.Method == "GET" {
3346 io.WriteString(w, "bar")
3347 return
3348 }
3349 conn, _, _ := w.(Hijacker).Hijack()
3350 sconn.Lock()
3351 sconn.c = conn
3352 sconn.Unlock()
3353 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3354 go io.Copy(io.Discard, conn)
3355 }))
3356 defer ts.Close()
3357 c := ts.Client()
3358
3359 const bodySize = 256 << 10
3360 finalBit := make(byteFromChanReader, 1)
3361 req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
3362 req.ContentLength = bodySize
3363 res, err := c.Do(req)
3364 if err := wantBody(res, err, "foo"); err != nil {
3365 t.Errorf("POST response: %v", err)
3366 }
3367 donec := make(chan bool)
3368 go func() {
3369 defer close(donec)
3370 res, err = c.Get(ts.URL)
3371 if err := wantBody(res, err, "bar"); err != nil {
3372 t.Errorf("GET response: %v", err)
3373 return
3374 }
3375 getOkay = true
3376 }()
3377 time.AfterFunc(5*time.Second, closeConn)
3378 select {
3379 case <-donec:
3380 finalBit <- 'x'
3381 close(finalBit)
3382 case <-time.After(7 * time.Second):
3383 t.Fatal("timeout waiting for GET request to finish")
3384 }
3385 }
3386
3387
3388
3389 func TestTransportIssue10457(t *testing.T) {
3390 defer afterTest(t)
3391 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3392
3393
3394
3395
3396
3397 conn, _, _ := w.(Hijacker).Hijack()
3398 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3399 conn.Close()
3400 }))
3401 defer ts.Close()
3402 c := ts.Client()
3403
3404 res, err := c.Get(ts.URL)
3405 if err != nil {
3406 t.Fatalf("Get: %v", err)
3407 }
3408 defer res.Body.Close()
3409
3410
3411
3412
3413 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3414 t.Errorf("Foo header = %q; want %q", got, want)
3415 }
3416 }
3417
// closerFunc adapts an ordinary function to the io.Closer interface.
type closerFunc func() error

// Close invokes the wrapped function and returns its error.
func (fn closerFunc) Close() error {
	return fn()
}
3421
// writerFuncConn is a net.Conn whose Write goes to the write hook. All other
// methods come from the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write passes p to the write hook and returns the hook's results
// unchanged.
func (wc writerFuncConn) Write(p []byte) (n int, err error) {
	n, err = wc.write(p)
	return n, err
}
3428
3429
3430
3431
3432
3433
3434
3435
3436
3437
3438
3439
3440 func TestRetryRequestsOnError(t *testing.T) {
3441 newRequest := func(method, urlStr string, body io.Reader) *Request {
3442 req, err := NewRequest(method, urlStr, body)
3443 if err != nil {
3444 t.Fatal(err)
3445 }
3446 return req
3447 }
3448
3449 testCases := []struct {
3450 name string
3451 failureN int
3452 failureErr error
3453
3454
3455
3456 req func() *Request
3457 reqString string
3458 }{
3459 {
3460 name: "IdempotentNoBodySomeWritten",
3461
3462
3463 failureN: 1,
3464
3465 failureErr: ExportErrServerClosedIdle,
3466 req: func() *Request {
3467 return newRequest("GET", "http://fake.golang", nil)
3468 },
3469 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3470 },
3471 {
3472 name: "IdempotentGetBodySomeWritten",
3473
3474
3475 failureN: 1,
3476
3477 failureErr: ExportErrServerClosedIdle,
3478 req: func() *Request {
3479 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3480 },
3481 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3482 },
3483 {
3484 name: "NothingWrittenNoBody",
3485
3486
3487 failureN: 0,
3488 failureErr: errors.New("second write fails"),
3489 req: func() *Request {
3490 return newRequest("DELETE", "http://fake.golang", nil)
3491 },
3492 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3493 },
3494 {
3495 name: "NothingWrittenGetBody",
3496
3497
3498 failureN: 0,
3499 failureErr: errors.New("second write fails"),
3500
3501
3502 req: func() *Request {
3503 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3504 },
3505 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3506 },
3507 }
3508
3509 for _, tc := range testCases {
3510 t.Run(tc.name, func(t *testing.T) {
3511 defer afterTest(t)
3512
3513 var (
3514 mu sync.Mutex
3515 logbuf bytes.Buffer
3516 )
3517 logf := func(format string, args ...any) {
3518 mu.Lock()
3519 defer mu.Unlock()
3520 fmt.Fprintf(&logbuf, format, args...)
3521 logbuf.WriteByte('\n')
3522 }
3523
3524 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3525 logf("Handler")
3526 w.Header().Set("X-Status", "ok")
3527 }))
3528 defer ts.Close()
3529
3530 var writeNumAtomic int32
3531 c := ts.Client()
3532 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3533 logf("Dial")
3534 c, err := net.Dial(network, ts.Listener.Addr().String())
3535 if err != nil {
3536 logf("Dial error: %v", err)
3537 return nil, err
3538 }
3539 return &writerFuncConn{
3540 Conn: c,
3541 write: func(p []byte) (n int, err error) {
3542 if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3543 logf("intentional write failure")
3544 return tc.failureN, tc.failureErr
3545 }
3546 logf("Write(%q)", p)
3547 return c.Write(p)
3548 },
3549 }, nil
3550 }
3551
3552 SetRoundTripRetried(func() {
3553 logf("Retried.")
3554 })
3555 defer SetRoundTripRetried(nil)
3556
3557 for i := 0; i < 3; i++ {
3558 t0 := time.Now()
3559 req := tc.req()
3560 res, err := c.Do(req)
3561 if err != nil {
3562 if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 {
3563 mu.Lock()
3564 got := logbuf.String()
3565 mu.Unlock()
3566 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
3567 }
3568 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse)
3569 }
3570 res.Body.Close()
3571 if res.Request != req {
3572 t.Errorf("Response.Request != original request; want identical Request")
3573 }
3574 }
3575
3576 mu.Lock()
3577 got := logbuf.String()
3578 mu.Unlock()
3579 want := fmt.Sprintf(`Dial
3580 Write("%s")
3581 Handler
3582 intentional write failure
3583 Retried.
3584 Dial
3585 Write("%s")
3586 Handler
3587 Write("%s")
3588 Handler
3589 `, tc.reqString, tc.reqString, tc.reqString)
3590 if got != want {
3591 t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
3592 }
3593 })
3594 }
3595 }
3596
3597
3598 func TestTransportClosesBodyOnError(t *testing.T) {
3599 setParallel(t)
3600 defer afterTest(t)
3601 readBody := make(chan error, 1)
3602 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3603 _, err := io.ReadAll(r.Body)
3604 readBody <- err
3605 }))
3606 defer ts.Close()
3607 c := ts.Client()
3608 fakeErr := errors.New("fake error")
3609 didClose := make(chan bool, 1)
3610 req, _ := NewRequest("POST", ts.URL, struct {
3611 io.Reader
3612 io.Closer
3613 }{
3614 io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
3615 closerFunc(func() error {
3616 select {
3617 case didClose <- true:
3618 default:
3619 }
3620 return nil
3621 }),
3622 })
3623 res, err := c.Do(req)
3624 if res != nil {
3625 defer res.Body.Close()
3626 }
3627 if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
3628 t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
3629 }
3630 select {
3631 case err := <-readBody:
3632 if err == nil {
3633 t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
3634 }
3635 case <-time.After(5 * time.Second):
3636 t.Error("timeout waiting for server handler to complete")
3637 }
3638 select {
3639 case <-didClose:
3640 default:
3641 t.Errorf("didn't see Body.Close")
3642 }
3643 }
3644
3645 func TestTransportDialTLS(t *testing.T) {
3646 setParallel(t)
3647 defer afterTest(t)
3648 var mu sync.Mutex
3649 var gotReq, didDial bool
3650
3651 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3652 mu.Lock()
3653 gotReq = true
3654 mu.Unlock()
3655 }))
3656 defer ts.Close()
3657 c := ts.Client()
3658 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
3659 mu.Lock()
3660 didDial = true
3661 mu.Unlock()
3662 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3663 if err != nil {
3664 return nil, err
3665 }
3666 return c, c.Handshake()
3667 }
3668
3669 res, err := c.Get(ts.URL)
3670 if err != nil {
3671 t.Fatal(err)
3672 }
3673 res.Body.Close()
3674 mu.Lock()
3675 if !gotReq {
3676 t.Error("didn't get request")
3677 }
3678 if !didDial {
3679 t.Error("didn't use dial hook")
3680 }
3681 }
3682
3683 func TestTransportDialContext(t *testing.T) {
3684 setParallel(t)
3685 defer afterTest(t)
3686 var mu sync.Mutex
3687 var gotReq bool
3688 var receivedContext context.Context
3689
3690 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3691 mu.Lock()
3692 gotReq = true
3693 mu.Unlock()
3694 }))
3695 defer ts.Close()
3696 c := ts.Client()
3697 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3698 mu.Lock()
3699 receivedContext = ctx
3700 mu.Unlock()
3701 return net.Dial(netw, addr)
3702 }
3703
3704 req, err := NewRequest("GET", ts.URL, nil)
3705 if err != nil {
3706 t.Fatal(err)
3707 }
3708 ctx := context.WithValue(context.Background(), "some-key", "some-value")
3709 res, err := c.Do(req.WithContext(ctx))
3710 if err != nil {
3711 t.Fatal(err)
3712 }
3713 res.Body.Close()
3714 mu.Lock()
3715 if !gotReq {
3716 t.Error("didn't get request")
3717 }
3718 if receivedContext != ctx {
3719 t.Error("didn't receive correct context")
3720 }
3721 }
3722
3723 func TestTransportDialTLSContext(t *testing.T) {
3724 setParallel(t)
3725 defer afterTest(t)
3726 var mu sync.Mutex
3727 var gotReq bool
3728 var receivedContext context.Context
3729
3730 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3731 mu.Lock()
3732 gotReq = true
3733 mu.Unlock()
3734 }))
3735 defer ts.Close()
3736 c := ts.Client()
3737 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3738 mu.Lock()
3739 receivedContext = ctx
3740 mu.Unlock()
3741 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3742 if err != nil {
3743 return nil, err
3744 }
3745 return c, c.HandshakeContext(ctx)
3746 }
3747
3748 req, err := NewRequest("GET", ts.URL, nil)
3749 if err != nil {
3750 t.Fatal(err)
3751 }
3752 ctx := context.WithValue(context.Background(), "some-key", "some-value")
3753 res, err := c.Do(req.WithContext(ctx))
3754 if err != nil {
3755 t.Fatal(err)
3756 }
3757 res.Body.Close()
3758 mu.Lock()
3759 if !gotReq {
3760 t.Error("didn't get request")
3761 }
3762 if receivedContext != ctx {
3763 t.Error("didn't receive correct context")
3764 }
3765 }
3766
3767
3768
// TestRoundTripReturnsProxyError checks that an error returned by
// Transport.Proxy is surfaced by RoundTrip.
func TestRoundTripReturnsProxyError(t *testing.T) {
	proxyErr := errors.New("errorMessage")
	badProxy := func(*Request) (*url.URL, error) {
		return nil, proxyErr
	}

	tr := &Transport{Proxy: badProxy}

	req, _ := NewRequest("GET", "http://example.com", nil)

	_, err := tr.RoundTrip(req)

	if err == nil {
		t.Fatal("Expected proxy error to be returned by RoundTrip")
	}
	// Make sure the error is the proxy's, not some unrelated failure.
	if !strings.Contains(err.Error(), proxyErr.Error()) {
		t.Errorf("RoundTrip error = %v; want it to contain %q", err, proxyErr.Error())
	}
}
3784
3785
3786 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
3787 tr := &Transport{}
3788 wantIdle := func(when string, n int) bool {
3789 got := tr.IdleConnCountForTesting("http", "example.com")
3790 if got == n {
3791 return true
3792 }
3793 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
3794 return false
3795 }
3796 wantIdle("start", 0)
3797 if !tr.PutIdleTestConn("http", "example.com") {
3798 t.Fatal("put failed")
3799 }
3800 if !tr.PutIdleTestConn("http", "example.com") {
3801 t.Fatal("second put failed")
3802 }
3803 wantIdle("after put", 2)
3804 tr.CloseIdleConnections()
3805 if !tr.IsIdleForTesting() {
3806 t.Error("should be idle after CloseIdleConnections")
3807 }
3808 wantIdle("after close idle", 0)
3809 if tr.PutIdleTestConn("http", "example.com") {
3810 t.Fatal("put didn't fail")
3811 }
3812 wantIdle("after second put", 0)
3813
3814 tr.QueueForIdleConnForTesting()
3815 if tr.IsIdleForTesting() {
3816 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
3817 }
3818 if !tr.PutIdleTestConn("http", "example.com") {
3819 t.Fatal("after re-activation")
3820 }
3821 wantIdle("after final put", 1)
3822 }
3823
3824
3825
3826 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
3827 tr := &Transport{}
3828 wantIdle := func(when string, n int) bool {
3829 got := tr.IdleConnCountForTesting("https", "example.com:443")
3830 if got == n {
3831 return true
3832 }
3833 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
3834 return false
3835 }
3836 wantIdle("start", 0)
3837 alt := funcRoundTripper(func() {})
3838 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
3839 t.Fatal("put failed")
3840 }
3841 wantIdle("after put", 1)
3842 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3843 GotConn: func(httptrace.GotConnInfo) {
3844
3845 t.Error("GotConn called")
3846 },
3847 })
3848 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
3849 _, err := tr.RoundTrip(req)
3850 if err != errFakeRoundTrip {
3851 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
3852 }
3853 wantIdle("after round trip", 1)
3854 }
3855
3856 func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
3857 if testing.Short() {
3858 t.Skip("skipping in short mode")
3859 }
3860
3861 trFunc := func(tr *Transport) {
3862 tr.MaxConnsPerHost = 1
3863 tr.MaxIdleConnsPerHost = 1
3864 tr.IdleConnTimeout = 10 * time.Millisecond
3865 }
3866 cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
3867 defer cst.close()
3868
3869 if _, err := cst.c.Get(cst.ts.URL); err != nil {
3870 t.Fatalf("got error: %s", err)
3871 }
3872
3873 time.Sleep(100 * time.Millisecond)
3874 got := make(chan error)
3875 go func() {
3876 if _, err := cst.c.Get(cst.ts.URL); err != nil {
3877 got <- err
3878 }
3879 close(got)
3880 }()
3881
3882 timeout := time.NewTimer(5 * time.Second)
3883 defer timeout.Stop()
3884 select {
3885 case err := <-got:
3886 if err != nil {
3887 t.Fatalf("got error: %s", err)
3888 }
3889 case <-timeout.C:
3890 t.Fatal("request never completed")
3891 }
3892 }
3893
3894
3895
3896
3897
3898 func TestTransportRangeAndGzip(t *testing.T) {
3899 defer afterTest(t)
3900 reqc := make(chan *Request, 1)
3901 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3902 reqc <- r
3903 }))
3904 defer ts.Close()
3905 c := ts.Client()
3906
3907 req, _ := NewRequest("GET", ts.URL, nil)
3908 req.Header.Set("Range", "bytes=7-11")
3909 res, err := c.Do(req)
3910 if err != nil {
3911 t.Fatal(err)
3912 }
3913
3914 select {
3915 case r := <-reqc:
3916 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
3917 t.Error("Transport advertised gzip support in the Accept header")
3918 }
3919 if r.Header.Get("Range") == "" {
3920 t.Error("no Range in request")
3921 }
3922 case <-time.After(10 * time.Second):
3923 t.Fatal("timeout")
3924 }
3925 res.Body.Close()
3926 }
3927
3928
// TestTransportResponseCancelRace checks that calling CancelRequest on a
// request whose response body has already been fully read does not break a
// following request. The second RoundTrip, made after that stale cancel,
// must succeed.
func TestTransportResponseCancelRace(t *testing.T) {
	defer afterTest(t)

	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		// Give the response a non-empty (1024-byte) body.
		var b [1024]byte
		w.Write(b[:])
	}))
	defer ts.Close()
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}

	// Drain the body to EOF instead of closing it early. Presumably this
	// lets the connection be reused by req2, which is the situation the
	// race needs. NOTE(review): this first res.Body is never closed.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		t.Fatal(err)
	}

	req2, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Cancel the already-finished req; this must not affect req2.
	tr.CancelRequest(req)
	res, err = tr.RoundTrip(req2)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
3966
3967
3968 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
3969 setParallel(t)
3970 defer afterTest(t)
3971 for _, ce := range []string{"gzip", "GZIP"} {
3972 ce := ce
3973 t.Run(ce, func(t *testing.T) {
3974 const encodedString = "Hello Gopher"
3975 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3976 w.Header().Set("Content-Encoding", ce)
3977 gz := gzip.NewWriter(w)
3978 gz.Write([]byte(encodedString))
3979 gz.Close()
3980 }))
3981 defer ts.Close()
3982
3983 res, err := ts.Client().Get(ts.URL)
3984 if err != nil {
3985 t.Fatal(err)
3986 }
3987
3988 body, err := io.ReadAll(res.Body)
3989 res.Body.Close()
3990 if err != nil {
3991 t.Fatal(err)
3992 }
3993
3994 if string(body) != encodedString {
3995 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
3996 }
3997 })
3998 }
3999 }
4000
4001 func TestTransportDialCancelRace(t *testing.T) {
4002 defer afterTest(t)
4003
4004 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
4005 defer ts.Close()
4006 tr := ts.Client().Transport.(*Transport)
4007
4008 req, err := NewRequest("GET", ts.URL, nil)
4009 if err != nil {
4010 t.Fatal(err)
4011 }
4012 SetEnterRoundTripHook(func() {
4013 tr.CancelRequest(req)
4014 })
4015 defer SetEnterRoundTripHook(nil)
4016 res, err := tr.RoundTrip(req)
4017 if err != ExportErrRequestCanceled {
4018 t.Errorf("expected canceled request error; got %v", err)
4019 if err == nil {
4020 res.Body.Close()
4021 }
4022 }
4023 }
4024
4025
4026
4027
// logWritesConn is a fake net.Conn for tests. It records each Write call's
// bytes as one string in writes and forwards them to w. Reads are served
// from the first io.Reader received on rch.
type logWritesConn struct {
	// Embedded only to satisfy net.Conn. Callers may leave it nil; any
	// method not overridden below would then panic.
	net.Conn

	w io.Writer // destination for written bytes

	rch <-chan io.Reader // delivers the reader that Read consumes
	r   io.Reader        // nil until received from rch

	mu     sync.Mutex // guards writes
	writes []string   // one entry per Write call, in order
}

// Write records p in c.writes (under c.mu) and then forwards it to c.w.
func (c *logWritesConn) Write(p []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, string(p))
	return c.w.Write(p)
}

// Read blocks until a reader arrives on c.rch (first call only), then reads
// from it. Not guarded by c.mu; presumably assumes a single reader.
func (c *logWritesConn) Read(p []byte) (n int, err error) {
	if c.r == nil {
		c.r = <-c.rch
	}
	return c.r.Read(p)
}

// Close is a no-op that always returns nil.
func (c *logWritesConn) Close() error { return nil }
4055
4056
4057 func TestTransportFlushesBodyChunks(t *testing.T) {
4058 defer afterTest(t)
4059 resBody := make(chan io.Reader, 1)
4060 connr, connw := io.Pipe()
4061 lw := &logWritesConn{
4062 rch: resBody,
4063 w: connw,
4064 }
4065 tr := &Transport{
4066 Dial: func(network, addr string) (net.Conn, error) {
4067 return lw, nil
4068 },
4069 }
4070 bodyr, bodyw := io.Pipe()
4071 go func() {
4072 defer bodyw.Close()
4073 for i := 0; i < 3; i++ {
4074 fmt.Fprintf(bodyw, "num%d\n", i)
4075 }
4076 }()
4077 resc := make(chan *Response)
4078 go func() {
4079 req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4080 req.Header.Set("User-Agent", "x")
4081 res, err := tr.RoundTrip(req)
4082 if err != nil {
4083 t.Errorf("RoundTrip: %v", err)
4084 close(resc)
4085 return
4086 }
4087 resc <- res
4088
4089 }()
4090
4091 req, err := ReadRequest(bufio.NewReader(connr))
4092 if err != nil {
4093 t.Fatal(err)
4094 }
4095 io.Copy(io.Discard, req.Body)
4096
4097
4098 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4099 res, ok := <-resc
4100 if !ok {
4101 return
4102 }
4103 defer res.Body.Close()
4104
4105 want := []string{
4106 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4107 "5\r\nnum0\n\r\n",
4108 "5\r\nnum1\n\r\n",
4109 "5\r\nnum2\n\r\n",
4110 "0\r\n\r\n",
4111 }
4112 if !reflect.DeepEqual(lw.writes, want) {
4113 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4114 }
4115 }
4116
4117
4118 func TestTransportFlushesRequestHeader(t *testing.T) {
4119 defer afterTest(t)
4120 gotReq := make(chan struct{})
4121 cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4122 close(gotReq)
4123 }))
4124 defer cst.close()
4125
4126 pr, pw := io.Pipe()
4127 req, err := NewRequest("POST", cst.ts.URL, pr)
4128 if err != nil {
4129 t.Fatal(err)
4130 }
4131 gotRes := make(chan struct{})
4132 go func() {
4133 defer close(gotRes)
4134 res, err := cst.tr.RoundTrip(req)
4135 if err != nil {
4136 t.Error(err)
4137 return
4138 }
4139 res.Body.Close()
4140 }()
4141
4142 select {
4143 case <-gotReq:
4144 pw.Close()
4145 case <-time.After(5 * time.Second):
4146 t.Fatal("timeout waiting for handler to get request")
4147 }
4148 <-gotRes
4149 }
4150
4151
4152 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4153 if testing.Short() {
4154 t.Skip("skipping in short mode")
4155 }
4156 defer afterTest(t)
4157 const contentLengthLimit = 1024 * 1024
4158 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4159 if r.ContentLength >= contentLengthLimit {
4160 w.WriteHeader(StatusBadRequest)
4161 r.Body.Close()
4162 return
4163 }
4164 w.WriteHeader(StatusOK)
4165 }))
4166 defer ts.Close()
4167 c := ts.Client()
4168
4169 fail := 0
4170 count := 100
4171 bigBody := strings.Repeat("a", contentLengthLimit*2)
4172 for i := 0; i < count; i++ {
4173 req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
4174 if err != nil {
4175 t.Fatal(err)
4176 }
4177 resp, err := c.Do(req)
4178 if err != nil {
4179 fail++
4180 t.Logf("%d = %#v", i, err)
4181 if ue, ok := err.(*url.Error); ok {
4182 t.Logf("urlErr = %#v", ue.Err)
4183 if ne, ok := ue.Err.(*net.OpError); ok {
4184 t.Logf("netOpError = %#v", ne.Err)
4185 }
4186 }
4187 } else {
4188 resp.Body.Close()
4189 if resp.StatusCode != 400 {
4190 t.Errorf("Expected status code 400, got %v", resp.Status)
4191 }
4192 }
4193 }
4194 if fail > 0 {
4195 t.Errorf("Failed %v out of %v\n", fail, count)
4196 }
4197 }
4198
4199 func TestTransportAutomaticHTTP2(t *testing.T) {
4200 testTransportAutoHTTP(t, &Transport{}, true)
4201 }
4202
4203 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4204 testTransportAutoHTTP(t, &Transport{
4205 ForceAttemptHTTP2: true,
4206 TLSClientConfig: new(tls.Config),
4207 }, true)
4208 }
4209
4210
4211 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4212 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4213 }
4214
4215 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4216 testTransportAutoHTTP(t, &Transport{
4217 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4218 }, false)
4219 }
4220
4221 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4222 testTransportAutoHTTP(t, &Transport{
4223 TLSClientConfig: new(tls.Config),
4224 }, false)
4225 }
4226
4227 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4228 testTransportAutoHTTP(t, &Transport{
4229 ExpectContinueTimeout: 1 * time.Second,
4230 }, true)
4231 }
4232
4233 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4234 var d net.Dialer
4235 testTransportAutoHTTP(t, &Transport{
4236 Dial: d.Dial,
4237 }, false)
4238 }
4239
4240 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4241 var d net.Dialer
4242 testTransportAutoHTTP(t, &Transport{
4243 DialContext: d.DialContext,
4244 }, false)
4245 }
4246
4247 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4248 testTransportAutoHTTP(t, &Transport{
4249 DialTLS: func(network, addr string) (net.Conn, error) {
4250 panic("unused")
4251 },
4252 }, false)
4253 }
4254
4255 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4256 CondSkipHTTP2(t)
4257 _, err := tr.RoundTrip(new(Request))
4258 if err == nil {
4259 t.Error("expected error from RoundTrip")
4260 }
4261 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4262 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4263 }
4264 }
4265
4266
4267
4268
4269
4270
4271
4272
4273 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4274 defer afterTest(t)
4275 cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4276 w.Header().Set("X-Addr", r.RemoteAddr)
4277
4278 }))
4279 defer cst.close()
4280 n := 100
4281 if testing.Short() {
4282 n = 10
4283 }
4284 var firstAddr string
4285 for i := 0; i < n; i++ {
4286 res, err := cst.c.Get(cst.ts.URL)
4287 if err != nil {
4288 log.Fatal(err)
4289 }
4290 addr := res.Header.Get("X-Addr")
4291 if i == 0 {
4292 firstAddr = addr
4293 } else if addr != firstAddr {
4294 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4295 }
4296 res.Body.Close()
4297 }
4298 }
4299
4300
// TestNoCrashReturningTransportAltConn checks that the Transport does not
// crash in the following situation. A TLS connection negotiates a protocol
// that has a TLSNextProto entry ("foo"), but it finishes dialing only after
// its request was canceled. DialTLS cancels the request itself and then
// blocks until Do has returned, so the finished conn arrives with no request
// left to serve.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Count pending dials so that wg.Wait at the end waits for the
	// canceled dial to be fully handled.
	// NOTE(review): assumes the SetPendingDialHooks callbacks bracket each
	// pending dial; confirm in export_test.go.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	testDone := make(chan struct{})
	defer close(testDone)
	// Server: accept one TLS conn offering only "foo", complete the
	// handshake, and keep the conn open until the test finishes.
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// The URL's host is never dialed; DialTLS below connects to addr.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				// The request was canceled, so nothing should ever be
				// sent through this alt RoundTripper.
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request, then hold the finished conn until Do
			// has returned, so the conn is delivered after the request
			// is gone.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release DialTLS. Then wait for the "foo" TLSNextProto func to run for
	// the late conn, and for the pending dial to finish.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4382
4383 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4384 testTransportReuseConnection_Gzip(t, true)
4385 }
4386
4387 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
4388 testTransportReuseConnection_Gzip(t, false)
4389 }
4390
4391
4392 func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
4393 setParallel(t)
4394 defer afterTest(t)
4395 addr := make(chan string, 2)
4396 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4397 addr <- r.RemoteAddr
4398 w.Header().Set("Content-Encoding", "gzip")
4399 if chunked {
4400 w.(Flusher).Flush()
4401 }
4402 w.Write(rgz)
4403 }))
4404 defer ts.Close()
4405 c := ts.Client()
4406
4407 for i := 0; i < 2; i++ {
4408 res, err := c.Get(ts.URL)
4409 if err != nil {
4410 t.Fatal(err)
4411 }
4412 buf := make([]byte, len(rgz))
4413 if n, err := io.ReadFull(res.Body, buf); err != nil {
4414 t.Errorf("%d. ReadFull = %v, %v", i, n, err)
4415 }
4416
4417
4418
4419 }
4420 a1, a2 := <-addr, <-addr
4421 if a1 != a2 {
4422 t.Fatalf("didn't reuse connection")
4423 }
4424 }
4425
4426 func TestTransportResponseHeaderLength(t *testing.T) {
4427 setParallel(t)
4428 defer afterTest(t)
4429 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4430 if r.URL.Path == "/long" {
4431 w.Header().Set("Long", strings.Repeat("a", 1<<20))
4432 }
4433 }))
4434 defer ts.Close()
4435 c := ts.Client()
4436 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
4437
4438 if res, err := c.Get(ts.URL); err != nil {
4439 t.Fatal(err)
4440 } else {
4441 res.Body.Close()
4442 }
4443
4444 res, err := c.Get(ts.URL + "/long")
4445 if err == nil {
4446 defer res.Body.Close()
4447 var n int64
4448 for k, vv := range res.Header {
4449 for _, v := range vv {
4450 n += int64(len(k)) + int64(len(v))
4451 }
4452 }
4453 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
4454 }
4455 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
4456 t.Errorf("got error: %v; want %q", err, want)
4457 }
4458 }
4459
4460 func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, h1Mode, false) }
4461 func TestTransportEventTrace_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, false) }
4462
4463
4464 func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, h1Mode, true) }
4465 func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, true) }
4466
// testTransportEventTrace sends a POST with "Expect: 100-continue" to the
// fake host "dns-is-faked.golang". A resolver installed in the context maps
// that name to the test server. The test then checks the
// httptrace.ClientTrace events logged for the POST. Finally it sends a GET
// with the same trace and checks that GetConn fired again.
//
// h2 selects HTTP/2 (h2Mode) or HTTP/1 (h1Mode). If noHooks is true, the
// trace is replaced with an empty ClientTrace and only request success is
// checked.
func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
	defer afterTest(t)
	const resBody = "some body"
	// Signaled by the WroteRequest hook and awaited by the handler.
	// Buffered so the hook never blocks.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The second (GET) request just gets an empty 200.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			// Hold the response until the client has logged
			// WroteRequest.
			select {
			case <-gotWroteReqEvent:
			case <-time.After(5 * time.Second):
				t.Error("timeout waiting for WroteRequest event")
			}
		}
		io.WriteString(w, resBody)
	}))
	defer cst.close()

	// Non-zero so the client waits for the server's 100 Continue before
	// sending the body (yielding Wait100Continue/Got100Continue events).
	cst.tr.ExpectContinueTimeout = 1 * time.Second

	var mu sync.Mutex // guards buf
	var buf bytes.Buffer
	// logf appends one formatted line to buf. Trace hooks and the test
	// goroutine may call it concurrently.
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Fake DNS via the context: only "dns-is-faked.golang" is expected,
	// and it resolves to the test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	// Every hook appends a line to buf; the checks below search that log.
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if h2 {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Keep the trace installed but clear every hook, to check that
		// a trace with no functions set is handled.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// With no hooks there is no log to inspect; completing the
		// request above is all this variant checks.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	// wantOnce requires sub exactly once in the log; wantOnceOrMore at
	// least once.
	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if h2 {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// The remaining header-field expectations are only checked for
		// HTTP/1; they use HTTP/1 header names (e.g. Host).
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	// A UDP connection would indicate a real DNS query, bypassing the
	// fake resolver.
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Second request with the same trace: GetConn must fire again, for
	// two "Getting conn" lines in total.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
4649
4650 func TestTransportEventTraceTLSVerify(t *testing.T) {
4651 var mu sync.Mutex
4652 var buf bytes.Buffer
4653 logf := func(format string, args ...any) {
4654 mu.Lock()
4655 defer mu.Unlock()
4656 fmt.Fprintf(&buf, format, args...)
4657 buf.WriteByte('\n')
4658 }
4659
4660 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4661 t.Error("Unexpected request")
4662 }))
4663 defer ts.Close()
4664 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
4665 logf("%s", p)
4666 return len(p), nil
4667 }), "", 0)
4668
4669 certpool := x509.NewCertPool()
4670 certpool.AddCert(ts.Certificate())
4671
4672 c := &Client{Transport: &Transport{
4673 TLSClientConfig: &tls.Config{
4674 ServerName: "dns-is-faked.golang",
4675 RootCAs: certpool,
4676 },
4677 }}
4678
4679 trace := &httptrace.ClientTrace{
4680 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
4681 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
4682 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
4683 },
4684 }
4685
4686 req, _ := NewRequest("GET", ts.URL, nil)
4687 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
4688 _, err := c.Do(req)
4689 if err == nil {
4690 t.Error("Expected request to fail TLS verification")
4691 }
4692
4693 mu.Lock()
4694 got := buf.String()
4695 mu.Unlock()
4696
4697 wantOnce := func(sub string) {
4698 if strings.Count(got, sub) != 1 {
4699 t.Errorf("expected substring %q exactly once in output.", sub)
4700 }
4701 }
4702
4703 wantOnce("TLSHandshakeStart")
4704 wantOnce("TLSHandshakeDone")
4705 wantOnce("err = x509: certificate is valid for example.com")
4706
4707 if t.Failed() {
4708 t.Errorf("Output:\n%s", got)
4709 }
4710 }
4711
// isDNSHijacked records, once per process (guarded by isDNSHijackedOnce),
// whether the system resolver returned addresses for a name that should
// never resolve.
var (
	isDNSHijackedOnce sync.Once
	isDNSHijacked     bool
)

// skipIfDNSHijacked skips t when the local resolver answers lookups for a
// name that should not exist (a hijacking DNS server). Tests that expect
// DNS failures cannot pass in that environment.
func skipIfDNSHijacked(t *testing.T) {
	isDNSHijackedOnce.Do(func() {
		// The lookup error is irrelevant; only whether any address came
		// back matters.
		if addrs, _ := net.LookupHost("dns-should-not-resolve.golang"); len(addrs) > 0 {
			isDNSHijacked = true
		}
	})
	if !isDNSHijacked {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
4729
4730 func TestTransportEventTraceRealDNS(t *testing.T) {
4731 skipIfDNSHijacked(t)
4732 defer afterTest(t)
4733 tr := &Transport{}
4734 defer tr.CloseIdleConnections()
4735 c := &Client{Transport: tr}
4736
4737 var mu sync.Mutex
4738 var buf bytes.Buffer
4739 logf := func(format string, args ...any) {
4740 mu.Lock()
4741 defer mu.Unlock()
4742 fmt.Fprintf(&buf, format, args...)
4743 buf.WriteByte('\n')
4744 }
4745
4746 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
4747 trace := &httptrace.ClientTrace{
4748 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
4749 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
4750 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
4751 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
4752 }
4753 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
4754
4755 resp, err := c.Do(req)
4756 if err == nil {
4757 resp.Body.Close()
4758 t.Fatal("expected error during DNS lookup")
4759 }
4760
4761 mu.Lock()
4762 got := buf.String()
4763 mu.Unlock()
4764
4765 wantSub := func(sub string) {
4766 if !strings.Contains(got, sub) {
4767 t.Errorf("expected substring %q in output.", sub)
4768 }
4769 }
4770 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
4771 wantSub("DNSDone: {Addrs:[] Err:")
4772 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
4773 t.Errorf("should not see Connect events")
4774 }
4775 if t.Failed() {
4776 t.Errorf("Output:\n%s", got)
4777 }
4778 }
4779
4780
// TestTransportRejectsAlphaPort checks that a URL whose port contains
// non-digit characters is rejected. The failure must be a *url.Error that
// wraps an "invalid port" error.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	ue, ok := err.(*url.Error)
	if !ok {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	const want = `invalid port ":123foo" after host`
	if got := ue.Err.Error(); got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
4797
4798
4799
4800 func TestTLSHandshakeTrace(t *testing.T) {
4801 defer afterTest(t)
4802 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
4803 defer ts.Close()
4804
4805 var mu sync.Mutex
4806 var start, done bool
4807 trace := &httptrace.ClientTrace{
4808 TLSHandshakeStart: func() {
4809 mu.Lock()
4810 defer mu.Unlock()
4811 start = true
4812 },
4813 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
4814 mu.Lock()
4815 defer mu.Unlock()
4816 done = true
4817 if err != nil {
4818 t.Fatal("Expected error to be nil but was:", err)
4819 }
4820 },
4821 }
4822
4823 c := ts.Client()
4824 req, err := NewRequest("GET", ts.URL, nil)
4825 if err != nil {
4826 t.Fatal("Unable to construct test request:", err)
4827 }
4828 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
4829
4830 r, err := c.Do(req)
4831 if err != nil {
4832 t.Fatal("Unexpected error making request:", err)
4833 }
4834 r.Body.Close()
4835 mu.Lock()
4836 defer mu.Unlock()
4837 if !start {
4838 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
4839 }
4840 if !done {
4841 t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't")
4842 }
4843 }
4844
4845 func TestTransportMaxIdleConns(t *testing.T) {
4846 defer afterTest(t)
4847 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4848
4849 }))
4850 defer ts.Close()
4851 c := ts.Client()
4852 tr := c.Transport.(*Transport)
4853 tr.MaxIdleConns = 4
4854
4855 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
4856 if err != nil {
4857 t.Fatal(err)
4858 }
4859 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
4860 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
4861 })
4862
4863 hitHost := func(n int) {
4864 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
4865 req = req.WithContext(ctx)
4866 res, err := c.Do(req)
4867 if err != nil {
4868 t.Fatal(err)
4869 }
4870 res.Body.Close()
4871 }
4872 for i := 0; i < 4; i++ {
4873 hitHost(i)
4874 }
4875 want := []string{
4876 "|http|host-0.dns-is-faked.golang:" + port,
4877 "|http|host-1.dns-is-faked.golang:" + port,
4878 "|http|host-2.dns-is-faked.golang:" + port,
4879 "|http|host-3.dns-is-faked.golang:" + port,
4880 }
4881 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
4882 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
4883 }
4884
4885
4886 hitHost(4)
4887 want = []string{
4888 "|http|host-1.dns-is-faked.golang:" + port,
4889 "|http|host-2.dns-is-faked.golang:" + port,
4890 "|http|host-3.dns-is-faked.golang:" + port,
4891 "|http|host-4.dns-is-faked.golang:" + port,
4892 }
4893 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
4894 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
4895 }
4896 }
4897
4898 func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) }
4899 func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) }
4900 func testTransportIdleConnTimeout(t *testing.T, h2 bool) {
4901 if testing.Short() {
4902 t.Skip("skipping in short mode")
4903 }
4904 defer afterTest(t)
4905
4906 const timeout = 1 * time.Second
4907
4908 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4909
4910 }))
4911 defer cst.close()
4912 tr := cst.tr
4913 tr.IdleConnTimeout = timeout
4914 defer tr.CloseIdleConnections()
4915 c := &Client{Transport: tr}
4916
4917 idleConns := func() []string {
4918 if h2 {
4919 return tr.IdleConnStrsForTesting_h2()
4920 } else {
4921 return tr.IdleConnStrsForTesting()
4922 }
4923 }
4924
4925 var conn string
4926 doReq := func(n int) {
4927 req, _ := NewRequest("GET", cst.ts.URL, nil)
4928 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4929 PutIdleConn: func(err error) {
4930 if err != nil {
4931 t.Errorf("failed to keep idle conn: %v", err)
4932 }
4933 },
4934 }))
4935 res, err := c.Do(req)
4936 if err != nil {
4937 t.Fatal(err)
4938 }
4939 res.Body.Close()
4940 conns := idleConns()
4941 if len(conns) != 1 {
4942 t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
4943 }
4944 if conn == "" {
4945 conn = conns[0]
4946 }
4947 if conn != conns[0] {
4948 t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n)
4949 }
4950 }
4951 for i := 0; i < 3; i++ {
4952 doReq(i)
4953 time.Sleep(timeout / 2)
4954 }
4955 time.Sleep(timeout * 3 / 2)
4956 if got := idleConns(); len(got) != 0 {
4957 t.Errorf("idle conns = %q; want none", got)
4958 }
4959 }
4960
4961
4962
4963
4964
4965
4966
4967
4968
4969
4970
4971
// TestIdleConnH2Crash exercises an HTTP/2 connection whose dial finishes only
// after the request that triggered it has already been canceled and failed.
//
// The custom DialTLS cancels the request's context itself. It then holds the
// new connection until the main goroutine reports that Do returned an error,
// and only then hands the connection back to the Transport. IdleConnTimeout
// is very short, and the test finally sleeps for 10× that timeout.
// NOTE(review): presumably the late connection lands in the idle pool and is
// expired there; the test passes as long as nothing panics — confirm intent.
func TestIdleConnH2Crash(t *testing.T) {
	setParallel(t)
	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Presumably never reached: the request fails before its connection
		// is returned from DialTLS.
	}))
	defer cst.close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is signaled once cst.c.Do has returned its expected error.
	// testDone is closed when the test returns, so DialTLS cannot block past
	// the end of the test.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the in-flight request while its dial is still "in progress".
		cancel()

		// Hold the connection until Do has observed the cancellation (or
		// the test ends, or 5s elapse).
		failTimer := time.NewTimer(5 * time.Second)
		defer failTimer.Stop()
		select {
		case <-sawDoErr:
		case <-testDone:
		case <-failTimer.C:
			t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail")
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Give the Transport 10× IdleConnTimeout to deal with the now-orphaned
	// connection before the test returns.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5027
// funcConn is a net.Conn whose Read and Write are supplied as functions.
// Close always succeeds and does nothing. Every other net.Conn method is
// promoted from the embedded Conn, which may be left nil; calling one of
// those methods then panics.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read delegates to the read callback.
func (fc funcConn) Read(b []byte) (int, error) {
	return fc.read(b)
}

// Write delegates to the write callback.
func (fc funcConn) Write(b []byte) (int, error) {
	return fc.write(b)
}

// Close is a no-op that reports success.
func (funcConn) Close() error {
	return nil
}
5037
5038
5039
5040 func TestTransportReturnsPeekError(t *testing.T) {
5041 errValue := errors.New("specific error value")
5042
5043 wrote := make(chan struct{})
5044 var wroteOnce sync.Once
5045
5046 tr := &Transport{
5047 Dial: func(network, addr string) (net.Conn, error) {
5048 c := funcConn{
5049 read: func([]byte) (int, error) {
5050 <-wrote
5051 return 0, errValue
5052 },
5053 write: func(p []byte) (int, error) {
5054 wroteOnce.Do(func() { close(wrote) })
5055 return len(p), nil
5056 },
5057 }
5058 return c, nil
5059 },
5060 }
5061 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5062 if err != errValue {
5063 t.Errorf("error = %#v; want %v", err, errValue)
5064 }
5065 }
5066
5067
func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) }
func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) }

// testTransportIDNA verifies that a request to a Unicode host name uses the
// host's punycode (IDNA ASCII) form everywhere. This covers:
//   - the Host header seen by the server;
//   - the TLS ServerName (checked only in h2 mode, where r.TLS is expected
//     to be set);
//   - the name passed to the DNS lookup;
//   - the host:port given to the GetConn trace hook.
func testTransportIDNA(t *testing.T, h2 bool) {
	defer afterTest(t)

	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// port is captured by the handler and filled in below by
	// net.SplitHostPort (the := there reuses this variable).
	var port string
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		want := punyDomain + ":" + port
		if r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if h2 {
			if r.TLS == nil {
				t.Errorf("r.TLS == nil")
			} else if r.TLS.ServerName != punyDomain {
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		w.Header().Set("Hit-Handler", "1")
	}))
	defer cst.close()

	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake resolver through the internal nettrace context hook.
	// It answers only for the punycode name and maps it to the test server's
	// IP; any other lookup is reported as an error.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			want := net.JoinHostPort(punyDomain, port)
			if hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	// The handler sets Hit-Handler; without it, the response came from
	// somewhere else, so dump it for diagnosis.
	if res.Header.Get("Hit-Handler") != "1" {
		out, err := httputil.DumpResponse(res, true)
		if err != nil {
			t.Fatal(err)
		}
		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
	}
}
5136
5137
// TestTransportProxyConnectHeader verifies that Transport.ProxyConnectHeader
// is sent on the CONNECT request the Transport issues to an HTTP proxy when
// it fetches an https URL.
func TestTransportProxyConnectHeader(t *testing.T) {
	defer afterTest(t)
	// reqc hands the CONNECT request seen by the fake proxy back to the test.
	reqc := make(chan *Request, 1)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		reqc <- r
		// Hang up without answering. Only the request headers matter, so
		// the client's Get is expected to fail.
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		c.Close()
	}))
	defer ts.Close()

	// Route every request through the test server acting as the proxy.
	c := ts.Client()
	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
		return url.Parse(ts.URL)
	}
	c.Transport.(*Transport).ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}

	res, err := c.Get("https://dummy.tld/")
	if err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}
	select {
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	case r := <-reqc:
		if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
			t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
		}
		if got, want := r.Header.Get("Other"), "bar"; got != want {
			t.Errorf("CONNECT request Other = %q; want %q", got, want)
		}
	}
}
5181
// TestTransportProxyGetConnectHeader verifies that when both
// ProxyConnectHeader and GetProxyConnectHeader are set, the CONNECT request
// sent to the proxy carries the headers returned by GetProxyConnectHeader
// ("foo2"/"bar2"), not the static ProxyConnectHeader values.
func TestTransportProxyGetConnectHeader(t *testing.T) {
	defer afterTest(t)
	// reqc hands the CONNECT request seen by the fake proxy back to the test.
	reqc := make(chan *Request, 1)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		reqc <- r
		// Hang up without answering. Only the request headers matter, so
		// the client's Get is expected to fail.
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		c.Close()
	}))
	defer ts.Close()

	// Route every request through the test server acting as the proxy.
	c := ts.Client()
	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
		return url.Parse(ts.URL)
	}

	// Static headers that the dynamic hook below is expected to override.
	c.Transport.(*Transport).ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}
	c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
		return Header{
			"User-Agent": {"foo2"},
			"Other":      {"bar2"},
		}, nil
	}

	res, err := c.Get("https://dummy.tld/")
	if err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}
	select {
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	case r := <-reqc:
		if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
			t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
		}
		if got, want := r.Header.Get("Other"), "bar2"; got != want {
			t.Errorf("CONNECT request Other = %q; want %q", got, want)
		}
	}
}
5232
// errFakeRoundTrip is the error returned by every funcRoundTripper.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the wrapped function for its
// side effect and then always fails with errFakeRoundTrip. It never returns
// a response.
type funcRoundTripper func()

// RoundTrip ignores the request, calls f, and reports errFakeRoundTrip.
func (f funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	f()
	var noResponse *Response
	return noResponse, errFakeRoundTrip
}
5241
// wantBody checks that res carries exactly the body want.
//
// A non-nil err (typically the error from the call that produced res) is
// returned unchanged, and res is not touched. Otherwise the body is read in
// full and then always closed, so it is not leaked even when the read fails
// or the content does not match.
//
// It returns nil only when the read succeeds, the content equals want, and
// Close succeeds.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, readErr := io.ReadAll(res.Body)
	// Close unconditionally; checking its error is deferred until the read
	// and content checks have run, preserving their precedence.
	closeErr := res.Body.Close()
	if readErr != nil {
		return fmt.Errorf("error reading body: %w", readErr)
	}
	if string(slurp) != want {
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if closeErr != nil {
		return fmt.Errorf("body Close = %w", closeErr)
	}
	return nil
}
5258
// newLocalListener returns a TCP listener on an OS-chosen loopback port.
// It tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// can be opened, the test fails immediately with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
5269
// countCloseReader is an io.ReadCloser. Reads come from the embedded Reader,
// and each Close call increments the int that n points to.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close bumps the shared counter and always succeeds.
func (r countCloseReader) Close() error {
	*r.n++
	return nil
}
5279
5280
// rgz is a fixed gzip-format byte stream used as test data. Its header
// starts with the gzip magic bytes (0x1f, 0x8b), compression method 8
// (deflate), and the FNAME flag (0x08). The stored original file name is
// "recursive".
// NOTE(review): presumably a gzip "quine" that decompresses to itself, used
// to exercise repeated/recursive decompression — confirm against its users.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5315
5316
5317
// TestMissingStatusNoPanic covers an HTTP proxy that answers the client's
// CONNECT (for an https URL) with a status line carrying a code but no
// reason phrase ("HTTP/1.1 400"). RoundTrip must return an error mentioning
// "unknown status code" and must not panic.
func TestMissingStatusNoPanic(t *testing.T) {
	t.Parallel()

	const want = "unknown status code"

	ln := newLocalListener(t)
	addr := ln.Addr().String()
	done := make(chan bool)
	fullAddrURL := fmt.Sprintf("http://%s", addr)
	raw := "HTTP/1.1 400\r\n" +
		"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
		"Content-Type: text/html; charset=utf-8\r\n" +
		"Content-Length: 10\r\n" +
		"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
		"Vary: Accept-Encoding\r\n\r\n" +
		"Aloha Olaa"

	// Fake proxy: accept one connection, send the canned response, then
	// drain until the client hangs up. done is closed when it finishes.
	go func() {
		defer close(done)

		conn, _ := ln.Accept()
		if conn != nil {
			io.WriteString(conn, raw)
			io.ReadAll(conn)
			conn.Close()
		}
	}()

	proxyURL, err := url.Parse(fullAddrURL)
	if err != nil {
		t.Fatalf("proxyURL: %v", err)
	}

	tr := &Transport{Proxy: ProxyURL(proxyURL)}

	req, _ := NewRequest("GET", "https://golang.org/", nil)
	res, err, panicked := doFetchCheckPanic(tr, req)
	if panicked {
		t.Error("panicked, expecting an error")
	}
	if res != nil && res.Body != nil {
		io.Copy(io.Discard, res.Body)
		res.Body.Close()
	}

	if err == nil || !strings.Contains(err.Error(), want) {
		t.Errorf("got=%v want=%q", err, want)
	}

	// Unblock Accept if it never ran, then wait for the proxy goroutine.
	ln.Close()
	<-done
}
5370
// doFetchCheckPanic calls tr.RoundTrip(req) and turns any panic inside it
// into panicked == true. When a panic is recovered, res and err are left at
// their zero values. The named results exist so the deferred recover can set
// panicked after RoundTrip unwinds.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5380
5381
5382
// TestNoBodyOnChunked304Response covers a 304 Not Modified response that
// nonetheless declares "Transfer-Encoding: chunked" (followed by an empty
// terminating chunk). The client must report it with res.Body == NoBody.
func TestNoBodyOnChunked304Response(t *testing.T) {
	defer afterTest(t)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
		buf.Flush()
		conn.Close()
	}))
	defer cst.close()

	// The handler closes the connection after this single response.
	// NOTE(review): keep-alives are presumably disabled so the Transport
	// does not try to reuse that already-closed connection — confirm.
	cst.tr.DisableKeepAlives = true

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	if res.Body != NoBody {
		t.Errorf("Unexpected body on 304 response")
	}
}
5408
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write hands b to the underlying function and returns its results unchanged.
func (fw funcWriter) Write(b []byte) (n int, err error) {
	n, err = fw(b)
	return n, err
}
5412
// doneContext wraps a Context so that it always looks finished. Done returns
// a closed channel and Err returns the stored err. Any other Context method
// (Deadline, Value) comes from the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a freshly made channel that is already closed.
func (doneContext) Done() <-chan struct{} {
	ch := make(chan struct{})
	close(ch)
	return ch
}

// Err reports the error the doneContext was built with.
func (dc doneContext) Err() error {
	return dc.err
}
5425
5426
5427 func TestTransportCheckContextDoneEarly(t *testing.T) {
5428 tr := &Transport{}
5429 req, _ := NewRequest("GET", "http://fake.example/", nil)
5430 wantErr := errors.New("some error")
5431 req = req.WithContext(doneContext{context.Background(), wantErr})
5432 _, err := tr.RoundTrip(req)
5433 if err != wantErr {
5434 t.Errorf("error = %v; want %v", err, wantErr)
5435 }
5436 }
5437
5438
5439
5440
5441
5442
// TestClientTimeoutKillsConn_BeforeHeaders verifies that when Client.Timeout
// expires before any response headers arrive, the client closes the
// underlying connection. The server handler hijacks the connection, never
// responds, and blocks in Read; that Read must return (0, io.EOF).
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	inHandler := make(chan net.Conn, 1)
	handlerReadReturned := make(chan bool, 1)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		inHandler <- conn
		n, err := conn.Read([]byte{0})
		if n != 0 || err != io.EOF {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		handlerReadReturned <- true
	}))
	defer cst.close()

	const timeout = 50 * time.Millisecond
	cst.c.Timeout = timeout

	_, err := cst.c.Get(cst.ts.URL)
	if err == nil {
		t.Fatal("unexpected Get succeess")
	}

	select {
	case c := <-inHandler:
		select {
		case <-handlerReadReturned:
			// Success: the client's close unblocked the handler's Read.
			return
		case <-time.After(5 * time.Second):
			t.Error("Handler's conn.Read seems to be stuck in Read")
			c.Close()
		}
	case <-time.After(timeout * 10):
		// The handler did not receive the connection within 10× the
		// client timeout. NOTE(review): presumably the client timed out
		// before the request reached the handler on a slow machine, so
		// the scenario was not exercised; skip rather than fail.
		t.Skip("skipping test on slow builder")
	}
}
5489
5490
5491
5492
5493
5494
// TestClientTimeoutKillsConn_AfterHeaders verifies that canceling a request
// after its response headers have arrived closes the underlying connection.
// Cancellation here goes through the Request.Cancel channel.
//
// The server announces a 100-byte body but sends only "foo", then blocks in
// Read on the hijacked connection. The client's body read must fail, and the
// handler's Read must return an error.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	inHandler := make(chan net.Conn, 1)
	handlerResult := make(chan error, 1)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		conn.Write([]byte("foo"))
		inHandler <- conn
		n, err := conn.Read([]byte{0})
		// Any error counts as the connection having been torn down (not
		// only io.EOF). NOTE(review): presumably the exact error depends on
		// how the client closes the connection — confirm.
		if n != 0 || err == nil {
			handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err)
			return
		}
		handlerResult <- nil
	}))
	defer cst.close()

	// Use a generous Client.Timeout; this test cancels explicitly via
	// req.Cancel once the headers have arrived.
	cst.c.Timeout = time.Minute
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := cst.c.Do(req)
	if err != nil {
		select {
		case <-inHandler:
			t.Fatalf("Get error: %v", err)
		default:
			// The handler never got the connection, so the scenario was not
			// exercised (presumably a slow machine); skip rather than fail.
			t.Skip("skipping test on slow builder")
		}
	}

	close(cancel)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Fatalf("unexpected success; read %q, nil", got)
	}

	select {
	case c := <-inHandler:
		select {
		case err := <-handlerResult:
			if err != nil {
				t.Errorf("handler: %v", err)
			}
			return
		case <-time.After(5 * time.Second):
			t.Error("Handler's conn.Read seems to be stuck in Read")
			c.Close()
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout")
	}
}
5566
// TestTransportResponseBodyWritableOnProtocolSwitch verifies the response
// body for a "101 Switching Protocols" upgrade. res.Body must implement
// io.ReadWriteCloser and carry the raw bidirectional stream:
//   - data the server wrote right after the headers ("Some buffered data")
//     must be readable from it;
//   - a line written to it must reach the server, which echoes it back
//     upper-cased.
func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	// done keeps the handler's connection open until the test returns.
	done := make(chan struct{})
	defer close(done)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
		// Echo one line back upper-cased.
		bs := bufio.NewScanner(conn)
		bs.Scan()
		fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
		<-done
	}))
	defer cst.close()

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req.Header.Set("Upgrade", "foo")
	req.Header.Set("Connection", "upgrade")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 101 {
		t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
	}
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
	}
	defer rwc.Close()
	bs := bufio.NewScanner(rwc)
	if !bs.Scan() {
		t.Fatalf("expected readable input")
	}
	if got, want := bs.Text(), "Some buffered data"; got != want {
		t.Errorf("read %q; want %q", got, want)
	}
	io.WriteString(rwc, "echo\n")
	if !bs.Scan() {
		t.Fatalf("expected another line")
	}
	if got, want := bs.Text(), "ECHO"; got != want {
		t.Errorf("read %q; want %q", got, want)
	}
}
5617
// TestTransportCONNECTBidi verifies that a client CONNECT request with a
// streaming (pipe) body yields a bidirectional tunnel. Lines written to the
// request body reach the server. The server's upper-cased echoes can be read
// from res.Body while the request body is still open.
func TestTransportCONNECTBidi(t *testing.T) {
	defer afterTest(t)
	const target = "backend:443"
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("unexpected method %q", r.Method)
			w.WriteHeader(500)
			return
		}
		if r.RequestURI != target {
			t.Errorf("unexpected CONNECT target %q", r.RequestURI)
			w.WriteHeader(500)
			return
		}
		nc, brw, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer nc.Close()
		nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
		// Echo each line upper-cased until the client closes its side.
		for {
			line, err := brw.ReadString('\n')
			if err != nil {
				if err != io.EOF {
					t.Error(err)
				}
				return
			}
			io.WriteString(brw, strings.ToUpper(line))
			brw.Flush()
		}
	}))
	defer cst.close()
	pr, pw := io.Pipe()
	defer pw.Close()
	req, err := NewRequest("CONNECT", cst.ts.URL, pr)
	if err != nil {
		t.Fatal(err)
	}
	// The handler checks that the request target (RequestURI) is exactly
	// target; it is set here through URL.Opaque.
	req.URL.Opaque = target
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Fatalf("status code = %d; want 200", res.StatusCode)
	}
	br := bufio.NewReader(res.Body)
	for _, str := range []string{"foo", "bar", "baz"} {
		fmt.Fprintf(pw, "%s\n", str)
		got, err := br.ReadString('\n')
		if err != nil {
			t.Fatal(err)
		}
		got = strings.TrimSpace(got)
		want := strings.ToUpper(str)
		if got != want {
			t.Fatalf("got %q; want %q", got, want)
		}
	}
}
5682
5683 func TestTransportRequestReplayable(t *testing.T) {
5684 someBody := io.NopCloser(strings.NewReader(""))
5685 tests := []struct {
5686 name string
5687 req *Request
5688 want bool
5689 }{
5690 {
5691 name: "GET",
5692 req: &Request{Method: "GET"},
5693 want: true,
5694 },
5695 {
5696 name: "GET_http.NoBody",
5697 req: &Request{Method: "GET", Body: NoBody},
5698 want: true,
5699 },
5700 {
5701 name: "GET_body",
5702 req: &Request{Method: "GET", Body: someBody},
5703 want: false,
5704 },
5705 {
5706 name: "POST",
5707 req: &Request{Method: "POST"},
5708 want: false,
5709 },
5710 {
5711 name: "POST_idempotency-key",
5712 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
5713 want: true,
5714 },
5715 {
5716 name: "POST_x-idempotency-key",
5717 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
5718 want: true,
5719 },
5720 {
5721 name: "POST_body",
5722 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
5723 want: false,
5724 },
5725 }
5726 for _, tt := range tests {
5727 t.Run(tt.name, func(t *testing.T) {
5728 got := tt.req.ExportIsReplayable()
5729 if got != tt.want {
5730 t.Errorf("replyable = %v; want %v", got, tt.want)
5731 }
5732 })
5733 }
5734 }
5735
5736
5737
// testMockTCPConn wraps a *net.TCPConn and records whether ReadFrom was ever
// called on it. This lets a test see whether a request body was written
// through the io.ReaderFrom path.
type testMockTCPConn struct {
	*net.TCPConn

	// ReadFromCalled becomes true on the first ReadFrom call and stays true.
	ReadFromCalled bool
}

// ReadFrom records the call and then delegates to the embedded *net.TCPConn.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (int64, error) {
	m.ReadFromCalled = true
	return m.TCPConn.ReadFrom(src)
}
5748
5749 func TestTransportRequestWriteRoundTrip(t *testing.T) {
5750 nBytes := int64(1 << 10)
5751 newFileFunc := func() (r io.Reader, done func(), err error) {
5752 f, err := os.CreateTemp("", "net-http-newfilefunc")
5753 if err != nil {
5754 return nil, nil, err
5755 }
5756
5757
5758 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
5759 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
5760 }
5761 if _, err := f.Seek(0, 0); err != nil {
5762 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
5763 }
5764
5765 done = func() {
5766 f.Close()
5767 os.Remove(f.Name())
5768 }
5769
5770 return f, done, nil
5771 }
5772
5773 newBufferFunc := func() (io.Reader, func(), error) {
5774 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
5775 }
5776
5777 cases := []struct {
5778 name string
5779 readerFunc func() (io.Reader, func(), error)
5780 contentLength int64
5781 expectedReadFrom bool
5782 }{
5783 {
5784 name: "file, length",
5785 readerFunc: newFileFunc,
5786 contentLength: nBytes,
5787 expectedReadFrom: true,
5788 },
5789 {
5790 name: "file, no length",
5791 readerFunc: newFileFunc,
5792 },
5793 {
5794 name: "file, negative length",
5795 readerFunc: newFileFunc,
5796 contentLength: -1,
5797 },
5798 {
5799 name: "buffer",
5800 contentLength: nBytes,
5801 readerFunc: newBufferFunc,
5802 },
5803 {
5804 name: "buffer, no length",
5805 readerFunc: newBufferFunc,
5806 },
5807 {
5808 name: "buffer, length -1",
5809 contentLength: -1,
5810 readerFunc: newBufferFunc,
5811 },
5812 }
5813
5814 for _, tc := range cases {
5815 t.Run(tc.name, func(t *testing.T) {
5816 r, cleanup, err := tc.readerFunc()
5817 if err != nil {
5818 t.Fatal(err)
5819 }
5820 defer cleanup()
5821
5822 tConn := &testMockTCPConn{}
5823 trFunc := func(tr *Transport) {
5824 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
5825 var d net.Dialer
5826 conn, err := d.DialContext(ctx, network, addr)
5827 if err != nil {
5828 return nil, err
5829 }
5830
5831 tcpConn, ok := conn.(*net.TCPConn)
5832 if !ok {
5833 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
5834 }
5835
5836 tConn.TCPConn = tcpConn
5837 return tConn, nil
5838 }
5839 }
5840
5841 cst := newClientServerTest(
5842 t,
5843 h1Mode,
5844 HandlerFunc(func(w ResponseWriter, r *Request) {
5845 io.Copy(io.Discard, r.Body)
5846 r.Body.Close()
5847 w.WriteHeader(200)
5848 }),
5849 trFunc,
5850 )
5851 defer cst.close()
5852
5853 req, err := NewRequest("PUT", cst.ts.URL, r)
5854 if err != nil {
5855 t.Fatal(err)
5856 }
5857 req.ContentLength = tc.contentLength
5858 req.Header.Set("Content-Type", "application/octet-stream")
5859 resp, err := cst.c.Do(req)
5860 if err != nil {
5861 t.Fatal(err)
5862 }
5863 defer resp.Body.Close()
5864 if resp.StatusCode != 200 {
5865 t.Fatalf("status code = %d; want 200", resp.StatusCode)
5866 }
5867
5868 if !tConn.ReadFromCalled && tc.expectedReadFrom {
5869 t.Fatalf("did not call ReadFrom")
5870 }
5871
5872 if tConn.ReadFromCalled && !tc.expectedReadFrom {
5873 t.Fatalf("ReadFrom was unexpectedly invoked")
5874 }
5875 })
5876 }
5877 }
5878
// TestTransportClone verifies that Transport.Clone copies every exported
// field. It builds a Transport with each exported field set to a non-zero
// value, clones it, and uses reflection to fail on any exported field that
// is zero in the clone. It also checks that a TLSNextProto entry survives
// cloning, and that cloning a zero Transport leaves TLSNextProto nil.
func TestTransportClone(t *testing.T) {
	// Every exported Transport field must be given a non-zero value here;
	// the reflection loop below fails for any field left zero.
	tr := &Transport{
		Proxy:                  func(*Request) (*url.URL, error) { panic("") },
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	tr2 := tr.Clone()
	rv := reflect.ValueOf(tr2).Elem()
	rt := rv.Type()
	for i := 0; i < rt.NumField(); i++ {
		sf := rt.Field(i)
		// Unexported fields are internal state and are not expected to be
		// copied.
		if !token.IsExported(sf.Name) {
			continue
		}
		if rv.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", sf.Name)
		}
	}

	if _, ok := tr2.TLSNextProto["foo"]; !ok {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// A nil TLSNextProto must stay nil in the clone, not become an empty map.
	tr = new(Transport)
	tr2 = tr.Clone()
	if tr2.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
5930
5931 func TestIs408(t *testing.T) {
5932 tests := []struct {
5933 in string
5934 want bool
5935 }{
5936 {"HTTP/1.0 408", true},
5937 {"HTTP/1.1 408", true},
5938 {"HTTP/1.8 408", true},
5939 {"HTTP/2.0 408", false},
5940 {"HTTP/1.1 408 ", true},
5941 {"HTTP/1.1 40", false},
5942 {"http/1.0 408", false},
5943 {"HTTP/1-1 408", false},
5944 }
5945 for _, tt := range tests {
5946 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
5947 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
5948 }
5949 }
5950 }
5951
5952 func TestTransportIgnores408(t *testing.T) {
5953
5954 defer log.SetOutput(log.Writer())
5955
5956 var logout bytes.Buffer
5957 log.SetOutput(&logout)
5958
5959 defer afterTest(t)
5960 const target = "backend:443"
5961
5962 cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5963 nc, _, err := w.(Hijacker).Hijack()
5964 if err != nil {
5965 t.Error(err)
5966 return
5967 }
5968 defer nc.Close()
5969 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
5970 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
5971 }))
5972 defer cst.close()
5973 req, err := NewRequest("GET", cst.ts.URL, nil)
5974 if err != nil {
5975 t.Fatal(err)
5976 }
5977 res, err := cst.c.Do(req)
5978 if err != nil {
5979 t.Fatal(err)
5980 }
5981 slurp, err := io.ReadAll(res.Body)
5982 if err != nil {
5983 t.Fatal(err)
5984 }
5985 if err != nil {
5986 t.Fatal(err)
5987 }
5988 if string(slurp) != "ok" {
5989 t.Fatalf("got %q; want ok", slurp)
5990 }
5991
5992 t0 := time.Now()
5993 for i := 0; i < 50; i++ {
5994 time.Sleep(time.Duration(i) * 5 * time.Millisecond)
5995 if cst.tr.IdleConnKeyCountForTesting() == 0 {
5996 if got := logout.String(); got != "" {
5997 t.Fatalf("expected no log output; got: %s", got)
5998 }
5999 return
6000 }
6001 }
6002 t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0))
6003 }
6004
// TestInvalidHeaderResponse checks how the client parses a response header
// line that has whitespace before the colon ("Foo : bar"). The value must be
// stored under the key "Foo " (trailing space kept). It must not be
// reachable as "Foo".
func TestInvalidHeaderResponse(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
			"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
			"Content-Type: text/html; charset=utf-8\r\n" +
			"Content-Length: 0\r\n" +
			"Foo : bar\r\n\r\n"))
		buf.Flush()
		conn.Close()
	}))
	defer cst.close()
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if v := res.Header.Get("Foo"); v != "" {
		t.Errorf(`unexpected "Foo" header: %q`, v)
	}
	if v := res.Header.Get("Foo "); v != "bar" {
		t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
	}
}
6031
// bodyCloser is an always-empty request body that records whether it was
// closed: the bool becomes true on Close.
type bodyCloser bool

// Close marks the body as closed and always succeeds.
func (b *bodyCloser) Close() error {
	*b = bodyCloser(true)
	return nil
}

// Read reports end of input immediately without reading anything.
func (b *bodyCloser) Read(p []byte) (int, error) {
	return 0, io.EOF
}
6041
6042
6043
6044 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6045 cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
6046 t.Errorf("Should not have been invoked")
6047 }))
6048 defer cst.Close()
6049
6050 u, _ := url.Parse(cst.URL)
6051
6052 tests := []struct {
6053 name string
6054 req *Request
6055 wantErr string
6056 }{
6057 {
6058 name: "invalid method",
6059 req: &Request{
6060 Method: " ",
6061 URL: u,
6062 },
6063 wantErr: "invalid method",
6064 },
6065 {
6066 name: "nil URL",
6067 req: &Request{
6068 Method: "GET",
6069 },
6070 wantErr: "nil Request.URL",
6071 },
6072 {
6073 name: "invalid header key",
6074 req: &Request{
6075 Method: "GET",
6076 Header: Header{"💡": {"emoji"}},
6077 URL: u,
6078 },
6079 wantErr: "invalid header field name",
6080 },
6081 {
6082 name: "invalid header value",
6083 req: &Request{
6084 Method: "POST",
6085 Header: Header{"key": {"\x19"}},
6086 URL: u,
6087 },
6088 wantErr: "invalid header field value",
6089 },
6090 {
6091 name: "non HTTP(s) scheme",
6092 req: &Request{
6093 Method: "POST",
6094 URL: &url.URL{Scheme: "faux"},
6095 },
6096 wantErr: "unsupported protocol scheme",
6097 },
6098 {
6099 name: "no Host in URL",
6100 req: &Request{
6101 Method: "POST",
6102 URL: &url.URL{Scheme: "http"},
6103 },
6104 wantErr: "no Host",
6105 },
6106 }
6107
6108 for _, tt := range tests {
6109 t.Run(tt.name, func(t *testing.T) {
6110 var bc bodyCloser
6111 req := tt.req
6112 req.Body = &bc
6113 _, err := DefaultClient.Do(tt.req)
6114 if err == nil {
6115 t.Fatal("Expected an error")
6116 }
6117 if !bc {
6118 t.Fatal("Expected body to have been closed")
6119 }
6120 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
6121 t.Fatalf("Error mismatch\n\t%q\ndoes not contain\n\t%q", g, w)
6122 }
6123 })
6124 }
6125 }
6126
6127
6128
// breakableConn is a net.Conn whose writes can be made to fail on demand.
// The flag lives in a shared *brokenState, so a test can break every
// connection that refers to the same state at once.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded flag. While broken is true, Write on any
// breakableConn holding this state fails.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write holds the state's lock for the whole call. While the state is broken
// it writes nothing and fails with "some write error"; otherwise it forwards
// b to the wrapped Conn.
func (c *breakableConn) Write(b []byte) (int, error) {
	c.Lock()
	defer c.Unlock()
	if !c.broken {
		return c.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6147
6148
// TestDontCacheBrokenHTTP2Conn verifies that an HTTP/2 connection whose
// writes start failing right after the TLS handshake is not kept for reuse.
//
// Each of the first numReqs-1 requests gets a connection that breaks after
// its handshake, and that request must fail. The last request gets a working
// connection and must succeed.
//
// The test expects exactly one GotConn callback: only the final, working
// connection is handed to a request. It also expects exactly one dial per
// request, meaning no broken connection was reused.
func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
	defer cst.close()

	// brokenState is shared by every connection dialed below.
	var brokenState brokenState

	const numReqs = 5
	var numDials, gotConns uint32 // atomic

	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
		atomic.AddUint32(&numDials, 1)
		c, err := net.Dial(netw, addr)
		if err != nil {
			t.Errorf("unexpected Dial error: %v", err)
			return nil, err
		}
		return &breakableConn{c, &brokenState}, err
	}

	for i := 1; i <= numReqs; i++ {
		brokenState.Lock()
		brokenState.broken = false
		brokenState.Unlock()

		// Break the connection right after its TLS handshake on every
		// iteration except the last, which must succeed on a fresh,
		// working connection.
		doBreak := i != numReqs

		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
				atomic.AddUint32(&gotConns, 1)
			},
			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
				brokenState.Lock()
				defer brokenState.Unlock()
				if doBreak {
					brokenState.broken = true
				}
			},
		})
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		_, err = cst.c.Do(req)
		if doBreak != (err != nil) {
			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
		}
	}
	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
		t.Errorf("GotConn calls = %v; want %v", got, want)
	}
	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
		t.Errorf("Dials = %v; want %v", got, want)
	}
}
6207
6208
6209
6210
6211
6212 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6213 defer afterTest(t)
6214 CondSkipHTTP2(t)
6215
6216 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6217 _, err := w.Write([]byte("foo"))
6218 if err != nil {
6219 t.Fatalf("Write: %v", err)
6220 }
6221 })
6222
6223 ts := httptest.NewUnstartedServer(h)
6224 ts.EnableHTTP2 = true
6225 ts.StartTLS()
6226 defer ts.Close()
6227
6228 c := ts.Client()
6229 tr := c.Transport.(*Transport)
6230 tr.MaxConnsPerHost = 1
6231 if err := ExportHttp2ConfigureTransport(tr); err != nil {
6232 t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
6233 }
6234
6235 errCh := make(chan error, 300)
6236 doReq := func() {
6237 resp, err := c.Get(ts.URL)
6238 if err != nil {
6239 errCh <- fmt.Errorf("request failed: %v", err)
6240 return
6241 }
6242 defer resp.Body.Close()
6243 _, err = io.ReadAll(resp.Body)
6244 if err != nil {
6245 errCh <- fmt.Errorf("read body failed: %v", err)
6246 }
6247 }
6248
6249 var wg sync.WaitGroup
6250 for i := 0; i < 300; i++ {
6251 wg.Add(1)
6252 go func() {
6253 defer wg.Done()
6254 doReq()
6255 }()
6256 }
6257 wg.Wait()
6258 close(errCh)
6259
6260 for err := range errCh {
6261 t.Errorf("error occurred: %v", err)
6262 }
6263 }
6264
6265
6266
6267
6268 func TestAltProtoCancellation(t *testing.T) {
6269 defer afterTest(t)
6270 tr := &Transport{}
6271 c := &Client{
6272 Transport: tr,
6273 Timeout: time.Millisecond,
6274 }
6275 tr.RegisterProtocol("timeout", timeoutProto{})
6276 _, err := c.Get("timeout://bar.com/path")
6277 if err == nil {
6278 t.Error("request unexpectedly succeeded")
6279 } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) {
6280 t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr)
6281 }
6282 }
6283
// timeoutProtoErr is the error timeoutProto reports once it observes that its
// request was canceled.
var timeoutProtoErr = errors.New("canceled as expected")

// timeoutProto is a RoundTripper that never produces a response. It waits for
// the request's Cancel channel and then returns timeoutProtoErr. If no
// cancellation arrives within 5 seconds, it gives up with a different error.
type timeoutProto struct{}

func (timeoutProto) RoundTrip(req *Request) (*Response, error) {
	giveUp := time.After(5 * time.Second)
	select {
	case <-giveUp:
		return nil, errors.New("request was not canceled")
	case <-req.Cancel:
		return nil, timeoutProtoErr
	}
}
6296
// roundTripFunc adapts an ordinary function to the RoundTripper interface,
// much as HandlerFunc does for Handler.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip invokes the underlying function with r and returns its results
// unchanged.
func (fn roundTripFunc) RoundTrip(r *Request) (*Response, error) {
	return fn(r)
}
6300
6301
6302 func TestIssue32441(t *testing.T) {
6303 defer afterTest(t)
6304 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
6305 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6306 t.Error("body length is zero")
6307 }
6308 }))
6309 defer ts.Close()
6310 c := ts.Client()
6311 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6312
6313 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6314 t.Error("body length is zero during round trip")
6315 }
6316 return nil, ErrSkipAltProtocol
6317 }))
6318 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6319 t.Error(err)
6320 }
6321 }
6322
6323
6324
// TestTransportRejectsSignInContentLength has a server reply with the
// Content-Length header "+3". The client must return an error mentioning the
// bad Content-Length and no Response.
func TestTransportRejectsSignInContentLength(t *testing.T) {
	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", "+3")
		w.Write([]byte("abc"))
	}))
	defer srv.Close()

	res, err := srv.Client().Get(srv.URL)
	if !(err != nil && res == nil) {
		t.Fatal("Expected a non-nil error and a nil http.Response")
	}

	const want = `bad Content-Length "+3"`
	if got := err.Error(); !strings.Contains(got, want) {
		t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
	}
}
6341
6342
// dumpConn is a net.Conn assembled from a plain Writer and Reader. Writes go
// to the embedded io.Writer and reads come from the embedded io.Reader. Every
// other net.Conn method is a no-op returning nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

var _ net.Conn = (*dumpConn)(nil)

func (*dumpConn) Close() error                     { return nil }
func (*dumpConn) LocalAddr() net.Addr              { return nil }
func (*dumpConn) RemoteAddr() net.Addr             { return nil }
func (*dumpConn) SetDeadline(time.Time) error      { return nil }
func (*dumpConn) SetReadDeadline(time.Time) error  { return nil }
func (*dumpConn) SetWriteDeadline(time.Time) error { return nil }
6354
6355
6356
// delegateReader is an io.Reader whose data source is supplied later. The
// first Read blocks until a reader arrives on c, and that reader then serves
// this and every later Read. If c is closed before a reader arrives, Read
// fails with "delegate closed".
type delegateReader struct {
	c chan io.Reader
	r io.Reader // nil until received from c
}

func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r == nil {
		src, ok := <-d.c
		if !ok {
			return 0, errors.New("delegate closed")
		}
		d.r = src
	}
	return d.r.Read(p)
}
6371
// testTransportRace sends req through a Transport whose only connection is
// in memory: everything the Transport writes lands on pw, and replies come
// from dr.
//
// A helper goroutine parses the written bytes as a request, drains its body,
// and then takes one of two paths:
//   - hand the Transport a canned "204 No Content" response through dr, or
//   - if the main goroutine is already waiting on quitReadCh, close dr.c so a
//     Read blocked on the delegate fails instead of hanging.
//
// After RoundTrip returns, it closes the pipe, waits for that goroutine to
// exit, and only then restores req.Body to its original value.
// The RoundTrip error is intentionally ignored.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})

	// Read the request off the pipe before replying with a dummy response.
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Drain the whole body so the Transport's writes into the
			// (synchronous) pipe are consumed.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// The main goroutine is already waiting on quitReadCh, i.e.
			// RoundTrip has returned. Close the delegate channel so any
			// Read blocked on it returns "delegate closed".
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Close the write side so the goroutine's ReadRequest/Copy see EOF,
	// then wait for it to finish before restoring req.Body below.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
6415
6416
6417
6418
6419
6420 func TestErrorWriteLoopRace(t *testing.T) {
6421 if testing.Short() {
6422 return
6423 }
6424 t.Parallel()
6425 for i := 0; i < 1000; i++ {
6426 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
6427 ctx, cancel := context.WithTimeout(context.Background(), delay)
6428 defer cancel()
6429
6430 r := bytes.NewBuffer(make([]byte, 10000))
6431 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
6432 if err != nil {
6433 t.Fatal(err)
6434 }
6435
6436 testTransportRace(req)
6437 }
6438 }
6439
6440
6441
6442
// TestCancelRequestWhenSharingConnection checks that canceling one request
// does not cancel another request using the same connection.
//
// The Transport is limited to one connection per host. Request 1 is parked
// inside its PutIdleConn trace hook while request 2 reaches the server on
// that connection, and request 2's context is then canceled. Request 1 must
// still succeed, and request 2 must fail with context.Canceled.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	// Each server-side request hands the test a channel and blocks until the
	// test closes it, so the test decides when each response is sent.
	reqc := make(chan chan struct{}, 2)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	}))
	defer ts.Close()

	client := ts.Client()
	transport := client.Transport.(*Transport)
	// A single connection per host forces request 2 onto request 1's
	// connection.
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{})
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Tell the test the connection is being returned to the idle
				// pool, then block until the test releases this hook.
				ch := make(chan struct{})
				putidlec <- ch
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if err != nil {
			t.Errorf("request 1: got err %v, want nil", err)
		}
	}()

	// Let request 1's handler respond, then wait until its connection is in
	// the middle of being returned to the idle pool.
	r1c := <-reqc
	close(r1c)
	idlec := <-putidlec

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}
	}()

	// Wait for request 2 to reach the server, then cancel its context.
	r2c := <-reqc
	cancel()

	// Give the cancellation a moment to take effect, then let request 1's
	// PutIdleConn hook return. (Timing-based; not a strict guarantee.)
	time.Sleep(1 * time.Millisecond)
	close(idlec)

	// Unblock request 2's handler and wait for both client goroutines.
	close(r2c)
	wg.Wait()
}
6515
6516 func TestHandlerAbortRacesBodyRead(t *testing.T) {
6517 setParallel(t)
6518 defer afterTest(t)
6519
6520 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
6521 go io.Copy(io.Discard, req.Body)
6522 panic(ErrAbortHandler)
6523 }))
6524 defer ts.Close()
6525
6526 var wg sync.WaitGroup
6527 for i := 0; i < 2; i++ {
6528 wg.Add(1)
6529 go func() {
6530 defer wg.Done()
6531 for j := 0; j < 10; j++ {
6532 const reqLen = 6 * 1024 * 1024
6533 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
6534 req.ContentLength = reqLen
6535 resp, _ := ts.Client().Transport.RoundTrip(req)
6536 if resp != nil {
6537 resp.Body.Close()
6538 }
6539 }
6540 }()
6541 }
6542 wg.Wait()
6543 }
6544
View as plain text