// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Tests for transport.go. // // More tests are in clientserver_test.go (for things testing both client & server for both // HTTP/1 and HTTP/2). This package http_test import ( "bufio" "bytes" "compress/gzip" "context" "crypto/rand" "crypto/tls" "crypto/x509" "encoding/binary" "errors" "fmt" "go/token" "internal/nettrace" "io" "log" mrand "math/rand" "net" . "net/http" "net/http/httptest" "net/http/httptrace" "net/http/httputil" "net/http/internal/testcert" "net/textproto" "net/url" "os" "reflect" "runtime" "strconv" "strings" "sync" "sync/atomic" "testing" "testing/iotest" "time" "golang.org/x/net/http/httpguts" ) // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close // and then verify that the final 2 responses get errors back. // hostPortHandler writes back the client's "host:port". var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("close") == "true" { w.Header().Set("Connection", "close") } w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close)) w.Write([]byte(r.RemoteAddr)) }) // testCloseConn is a net.Conn tracked by a testConnSet. type testCloseConn struct { net.Conn set *testConnSet } func (c *testCloseConn) Close() error { c.set.remove(c) return c.Conn.Close() } // testConnSet tracks a set of TCP connections and whether they've // been closed. 
type testConnSet struct { t *testing.T mu sync.Mutex // guards closed and list closed map[net.Conn]bool list []net.Conn // in order created } func (tcs *testConnSet) insert(c net.Conn) { tcs.mu.Lock() defer tcs.mu.Unlock() tcs.closed[c] = false tcs.list = append(tcs.list, c) } func (tcs *testConnSet) remove(c net.Conn) { tcs.mu.Lock() defer tcs.mu.Unlock() tcs.closed[c] = true } // some tests use this to manage raw tcp connections for later inspection func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) { connSet := &testConnSet{ t: t, closed: make(map[net.Conn]bool), } dial := func(n, addr string) (net.Conn, error) { c, err := net.Dial(n, addr) if err != nil { return nil, err } tc := &testCloseConn{c, connSet} connSet.insert(tc) return tc, nil } return connSet, dial } func (tcs *testConnSet) check(t *testing.T) { tcs.mu.Lock() defer tcs.mu.Unlock() for i := 4; i >= 0; i-- { for i, c := range tcs.list { if tcs.closed[c] { continue } if i != 0 { tcs.mu.Unlock() time.Sleep(50 * time.Millisecond) tcs.mu.Lock() continue } t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) } } } func TestReuseRequest(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("{}")) })) defer ts.Close() c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) res, err := c.Do(req) if err != nil { t.Fatal(err) } err = res.Body.Close() if err != nil { t.Fatal(err) } res, err = c.Do(req) if err != nil { t.Fatal(err) } err = res.Body.Close() if err != nil { t.Fatal(err) } } // Two subsequent requests and verify their response is the same. 
// The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() c := ts.Client() for _, disableKeepAlive := range []bool{false, true} { c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive fetch := func(n int) string { res, err := c.Get(ts.URL) if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) } body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) } return string(body) } body1 := fetch(1) body2 := fetch(2) bodiesDiffer := body1 != body2 if bodiesDiffer != disableKeepAlive { t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", disableKeepAlive, bodiesDiffer, body1, body2) } } } func TestTransportConnectionCloseOnResponse(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() connSet, testDial := makeTestDial(t) c := ts.Client() tr := c.Transport.(*Transport) tr.Dial = testDial for _, connectionClose := range []bool{false, true} { fetch := func(n int) string { req := new(Request) var err error req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) if err != nil { t.Fatalf("URL parse error: %v", err) } req.Method = "GET" req.Proto = "HTTP/1.1" req.ProtoMajor = 1 req.ProtoMinor = 1 res, err := c.Do(req) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) } defer res.Body.Close() body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } return string(body) } body1 := fetch(1) body2 := fetch(2) bodiesDiffer := body1 != body2 if bodiesDiffer != connectionClose { t.Errorf("error in connectionClose=%v. 
unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } tr.CloseIdleConnections() } connSet.check(t) } func TestTransportConnectionCloseOnRequest(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() connSet, testDial := makeTestDial(t) c := ts.Client() tr := c.Transport.(*Transport) tr.Dial = testDial for _, connectionClose := range []bool{false, true} { fetch := func(n int) string { req := new(Request) var err error req.URL, err = url.Parse(ts.URL) if err != nil { t.Fatalf("URL parse error: %v", err) } req.Method = "GET" req.Proto = "HTTP/1.1" req.ProtoMajor = 1 req.ProtoMinor = 1 req.Close = connectionClose res, err := c.Do(req) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) } if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(connectionClose); got != want { t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v", connectionClose, got, !connectionClose) } body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } return string(body) } body1 := fetch(1) body2 := fetch(2) bodiesDiffer := body1 != body2 if bodiesDiffer != connectionClose { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } tr.CloseIdleConnections() } connSet.check(t) } // if the Transport's DisableKeepAlives is set, all requests should // send Connection: close. 
// HTTP/1-only (Connection: close doesn't exist in h2) func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = true res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } res.Body.Close() if res.Header.Get("X-Saw-Close") != "true" { t.Errorf("handler didn't see Connection: close ") } } // Test that Transport only sends one "Connection: close", regardless of // how "close" was indicated. func TestTransportRespectRequestWantsClose(t *testing.T) { tests := []struct { disableKeepAlives bool close bool }{ {disableKeepAlives: false, close: false}, {disableKeepAlives: false, close: true}, {disableKeepAlives: true, close: false}, {disableKeepAlives: true, close: true}, } for _, tc := range tests { t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close), func(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } count := 0 trace := &httptrace.ClientTrace{ WroteHeaderField: func(key string, field []string) { if key != "Connection" { return } if httpguts.HeaderValuesContainsToken(field, "close") { count += 1 } }, } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) req.Close = tc.close res, err := c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want { t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count) } }) } } func TestTransportIdleCacheKeys(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Errorf("After 
CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) } resp, err := c.Get(ts.URL) if err != nil { t.Error(err) } io.ReadAll(resp.Body) keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) } if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) } tr.CloseIdleConnections() if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) } } // Tests that the HTTP transport re-uses connections when a client // reads to the end of a response Body without closing it. func TestTransportReadToEndReusesConn(t *testing.T) { defer afterTest(t) const msg = "foobar" var addrSeen map[string]int ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { addrSeen[r.RemoteAddr]++ if r.URL.Path == "/chunked/" { w.WriteHeader(200) w.(Flusher).Flush() } else { w.Header().Set("Content-Length", strconv.Itoa(len(msg))) w.WriteHeader(200) } w.Write([]byte(msg)) })) defer ts.Close() buf := make([]byte, len(msg)) for pi, path := range []string{"/content-length/", "/chunked/"} { wantLen := []int{len(msg), -1}[pi] addrSeen = make(map[string]int) for i := 0; i < 3; i++ { res, err := Get(ts.URL + path) if err != nil { t.Errorf("Get %s: %v", path, err) continue } // We want to close this body eventually (before the // defer afterTest at top runs), but not before the // len(addrSeen) check at the bottom of this test, // since Closing this early in the loop would risk // making connections be re-used for the wrong reason. 
defer res.Body.Close() if res.ContentLength != int64(wantLen) { t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) } n, err := res.Body.Read(buf) if n != len(msg) || err != io.EOF { t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) } } if len(addrSeen) != 1 { t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) } } } func TestTransportMaxPerHostIdleConns(t *testing.T) { defer afterTest(t) stop := make(chan struct{}) // stop marks the exit of main Test goroutine defer close(stop) resch := make(chan string) gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { gotReq <- true var msg string select { case <-stop: return case msg = <-resch: } _, err := w.Write([]byte(msg)) if err != nil { t.Errorf("Write: %v", err) return } })) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) maxIdleConnsPerHost := 2 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost // Start 3 outstanding requests and wait for the server to get them. // Their responses will hang until we write to resch, though. 
donech := make(chan bool) doReq := func() { defer func() { select { case <-stop: return case donech <- t.Failed(): } }() resp, err := c.Get(ts.URL) if err != nil { t.Error(err) return } if _, err := io.ReadAll(resp.Body); err != nil { t.Errorf("ReadAll: %v", err) return } } go doReq() <-gotReq go doReq() <-gotReq go doReq() <-gotReq if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) } resch <- "res1" <-donech keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) } addr := ts.Listener.Addr().String() cacheKey := "|http|" + addr if keys[0] != cacheKey { t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) } if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g { t.Errorf("after first response, expected %d idle conns; got %d", e, g) } resch <- "res2" <-donech if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w { t.Errorf("after second response, idle conns = %d; want %d", g, w) } resch <- "res3" <-donech if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w { t.Errorf("after third response, idle conns = %d; want %d", g, w) } } func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } })) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) dialStarted := make(chan struct{}) stallDial := make(chan struct{}) tr.Dial = func(network, addr string) (net.Conn, error) { dialStarted <- struct{}{} <-stallDial return net.Dial(network, addr) } tr.DisableKeepAlives = true tr.MaxConnsPerHost = 1 preDial := make(chan struct{}) reqComplete := make(chan struct{}) doReq := func(reqId string) { req, _ := NewRequest("GET", ts.URL, nil) trace := 
&httptrace.ClientTrace{ GetConn: func(hostPort string) { preDial <- struct{}{} }, } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) resp, err := tr.RoundTrip(req) if err != nil { t.Errorf("unexpected error for request %s: %v", reqId, err) } _, err = io.ReadAll(resp.Body) if err != nil { t.Errorf("unexpected error for request %s: %v", reqId, err) } reqComplete <- struct{}{} } // get req1 to dial-in-progress go doReq("req1") <-preDial <-dialStarted // get req2 to waiting on conns per host to go down below max go doReq("req2") <-preDial select { case <-dialStarted: t.Error("req2 dial started while req1 dial in progress") return default: } // let req1 complete stallDial <- struct{}{} <-reqComplete // let req2 complete <-dialStarted stallDial <- struct{}{} <-reqComplete } func TestTransportMaxConnsPerHost(t *testing.T) { defer afterTest(t) CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } }) testMaxConns := func(scheme string, ts *httptest.Server) { defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) tr.MaxConnsPerHost = 1 if err := ExportHttp2ConfigureTransport(tr); err != nil { t.Fatalf("ExportHttp2ConfigureTransport: %v", err) } mu := sync.Mutex{} var conns []net.Conn var dialCnt, gotConnCnt, tlsHandshakeCnt int32 tr.Dial = func(network, addr string) (net.Conn, error) { atomic.AddInt32(&dialCnt, 1) c, err := net.Dial(network, addr) mu.Lock() defer mu.Unlock() conns = append(conns, c) return c, err } doReq := func() { trace := &httptrace.ClientTrace{ GotConn: func(connInfo httptrace.GotConnInfo) { if !connInfo.Reused { atomic.AddInt32(&gotConnCnt, 1) } }, TLSHandshakeStart: func() { atomic.AddInt32(&tlsHandshakeCnt, 1) }, } req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) resp, err := c.Do(req) if err != nil { t.Fatalf("request failed: %v", err) } defer resp.Body.Close() _, 
err = io.ReadAll(resp.Body) if err != nil { t.Fatalf("read body failed: %v", err) } } wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() doReq() }() } wg.Wait() expected := int32(tr.MaxConnsPerHost) if dialCnt != expected { t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected) } if gotConnCnt != expected { t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) } if ts.TLS != nil && tlsHandshakeCnt != expected { t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) } if t.Failed() { t.FailNow() } mu.Lock() for _, c := range conns { c.Close() } conns = nil mu.Unlock() tr.CloseIdleConnections() doReq() expected++ if dialCnt != expected { t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt) } if gotConnCnt != expected { t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) } if ts.TLS != nil && tlsHandshakeCnt != expected { t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) } } testMaxConns("http", httptest.NewServer(h)) testMaxConns("https", httptest.NewTLSServer(h)) ts := httptest.NewUnstartedServer(h) ts.TLS = &tls.Config{NextProtos: []string{"h2"}} ts.StartTLS() testMaxConns("http2", ts) } func TestTransportRemovesDeadIdleConnections(t *testing.T) { setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) })) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) doReq := func(name string) string { // Do a POST instead of a GET to prevent the Transport's // idempotent request retry logic from kicking in... 
res, err := c.Post(ts.URL, "", nil) if err != nil { t.Fatalf("%s: %v", name, err) } if res.StatusCode != 200 { t.Fatalf("%s: %v", name, res.Status) } defer res.Body.Close() slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("%s: %v", name, err) } return string(slurp) } first := doReq("first") keys1 := tr.IdleConnKeysForTesting() ts.CloseClientConnections() var keys2 []string if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool { keys2 = tr.IdleConnKeysForTesting() return len(keys2) == 0 }) { t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2) } second := doReq("second") if first == second { t.Errorf("expected a different connection between requests. got %q both times", first) } } // Test that the Transport notices when a server hangs up on its // unexpectedly (a keep-alive connection is closed). func TestTransportServerClosingUnexpectedly(t *testing.T) { setParallel(t) defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() c := ts.Client() fetch := func(n, retries int) string { condFatalf := func(format string, arg ...any) { if retries <= 0 { t.Fatalf(format, arg...) } t.Logf("retrying shortly after expected error: "+format, arg...) time.Sleep(time.Second / time.Duration(retries)) } for retries >= 0 { retries-- res, err := c.Get(ts.URL) if err != nil { condFatalf("error in req #%d, GET: %v", n, err) continue } body, err := io.ReadAll(res.Body) if err != nil { condFatalf("error in req #%d, ReadAll: %v", n, err) continue } res.Body.Close() return string(body) } panic("unreachable") } body1 := fetch(1, 0) body2 := fetch(2, 0) // Close all the idle connections in a way that's similar to // the server hanging up on us. We don't use // httptest.Server.CloseClientConnections because it's // best-effort and stops blocking after 5 seconds. 
On a loaded // machine running many tests concurrently it's possible for // that method to be async and cause the body3 fetch below to // run on an old connection. This function is synchronous. ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) body3 := fetch(3, 5) if body1 != body2 { t.Errorf("expected body1 and body2 to be equal") } if body2 == body3 { t.Errorf("expected body2 and body3 to be different") } } // Test for https://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { defer afterTest(t) if testing.Short() { t.Skip("skipping test in short mode") } ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "5") w.Header().Set("Content-Type", "text/plain") w.Write([]byte("Hello")) w.(Flusher).Flush() conn, buf, _ := w.(Hijacker).Hijack() buf.Flush() conn.Close() })) defer ts.Close() c := ts.Client() // Do a bunch of traffic from different goroutines. Send to activityc // after each request completes, regardless of whether it failed. // If these are too high, OS X exhausts its ephemeral ports // and hangs waiting for them to transition TCP states. That's // not what we want to test. TODO(bradfitz): use an io.Pipe // dialer for this test instead? const ( numClients = 20 reqsPerClient = 25 ) activityc := make(chan bool) for i := 0; i < numClients; i++ { go func() { for i := 0; i < reqsPerClient; i++ { res, err := c.Get(ts.URL) if err == nil { // We expect errors since the server is // hanging up on us after telling us to // send more requests, so we don't // actually care what the error is. // But we want to close the body in cases // where we won the race. res.Body.Close() } if !<-activityc { // Receives false when close(activityc) is executed return } } }() } // Make sure all the request come back, one way or another. 
for i := 0; i < numClients*reqsPerClient; i++ { select { case activityc <- true: case <-time.After(5 * time.Second): close(activityc) t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") } } } // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly func TestTransportHeadResponses(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Content-Length", "123") w.WriteHeader(200) })) defer ts.Close() c := ts.Client() for i := 0; i < 2; i++ { res, err := c.Head(ts.URL) if err != nil { t.Errorf("error on loop %d: %v", i, err) continue } if e, g := "123", res.Header.Get("Content-Length"); e != g { t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) } if e, g := int64(123), res.ContentLength; e != g { t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } if all, err := io.ReadAll(res.Body); err != nil { t.Errorf("loop %d: Body ReadAll: %v", i, err) } else if len(all) != 0 { t.Errorf("Bogus body %q", all) } } } // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. 
func TestTransportHeadChunkedResponse(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Transfer-Encoding", "chunked") // client should ignore w.Header().Set("x-client-ipport", r.RemoteAddr) w.WriteHeader(200) })) defer ts.Close() c := ts.Client() // Ensure that we wait for the readLoop to complete before // calling Head again didRead := make(chan bool) SetReadLoopBeforeNextReadHook(func() { didRead <- true }) defer SetReadLoopBeforeNextReadHook(nil) res1, err := c.Head(ts.URL) <-didRead if err != nil { t.Fatalf("request 1 error: %v", err) } res2, err := c.Head(ts.URL) <-didRead if err != nil { t.Fatalf("request 2 error: %v", err) } if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) } } var roundTripTests = []struct { accept string expectAccept string compressed bool }{ // Requests with no accept-encoding header use transparent compression {"", "gzip", false}, // Requests with other accept-encoding should pass through unmodified {"foo", "foo", false}, // Requests with accept-encoding == gzip should be passed through {"gzip", "gzip", true}, } // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { setParallel(t) defer afterTest(t) const responseBody = "test response body" ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") if expect := req.FormValue("expect_accept"); accept != expect { t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", req.FormValue("testnum"), accept, expect) } if accept == "gzip" { rw.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(rw) gz.Write([]byte(responseBody)) gz.Close() } else { rw.Header().Set("Content-Encoding", accept) 
rw.Write([]byte(responseBody)) } })) defer ts.Close() tr := ts.Client().Transport.(*Transport) for i, test := range roundTripTests { // Test basic request (no accept-encoding) req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) if test.accept != "" { req.Header.Set("Accept-Encoding", test.accept) } res, err := tr.RoundTrip(req) if err != nil { t.Errorf("%d. RoundTrip: %v", i, err) continue } var body []byte if test.compressed { var r *gzip.Reader r, err = gzip.NewReader(res.Body) if err != nil { t.Errorf("%d. gzip NewReader: %v", i, err) continue } body, err = io.ReadAll(r) res.Body.Close() } else { body, err = io.ReadAll(res.Body) } if err != nil { t.Errorf("%d. Error: %q", i, err) continue } if g, e := string(body), responseBody; g != e { t.Errorf("%d. body = %q; want %q", i, g, e) } if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) } if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { t.Errorf("%d. 
Content-Encoding = %q; want %q", i, g, e) } } } func TestTransportGzip(t *testing.T) { setParallel(t) defer afterTest(t) const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { if req.Method == "HEAD" { if g := req.Header.Get("Accept-Encoding"); g != "" { t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) } return } if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { t.Errorf("Accept-Encoding = %q, want %q", g, e) } rw.Header().Set("Content-Encoding", "gzip") var w io.Writer = rw var buf bytes.Buffer if req.FormValue("chunked") == "0" { w = &buf defer io.Copy(rw, &buf) defer func() { rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) }() } gz := gzip.NewWriter(w) gz.Write([]byte(testString)) if req.FormValue("body") == "large" { io.CopyN(gz, rand.Reader, nRandBytes) } gz.Close() })) defer ts.Close() c := ts.Client() for _, chunked := range []string{"1", "0"} { // First fetch something large, but only read some of it. res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) if err != nil { t.Fatalf("large get: %v", err) } buf := make([]byte, len(testString)) n, err := io.ReadFull(res.Body, buf) if err != nil { t.Fatalf("partial read of large response: size=%d, %v", n, err) } if e, g := testString, string(buf); e != g { t.Errorf("partial read got %q, expected %q", g, e) } res.Body.Close() // Read on the body, even though it's closed n, err = res.Body.Read(buf) if n != 0 || err == nil { t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) } // Then something small. 
res, err = c.Get(ts.URL + "/?chunked=" + chunked) if err != nil { t.Fatal(err) } body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if g, e := string(body), testString; g != e { t.Fatalf("body = %q; want %q", g, e) } if g, e := res.Header.Get("Content-Encoding"), ""; g != e { t.Fatalf("Content-Encoding = %q; want %q", g, e) } // Read on the body after it's been fully read: n, err = res.Body.Read(buf) if n != 0 || err == nil { t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) } res.Body.Close() n, err = res.Body.Read(buf) if n != 0 || err == nil { t.Errorf("expected Read error after Close; got %d, %v", n, err) } } // And a HEAD request too, because they're always weird. res, err := c.Head(ts.URL) if err != nil { t.Fatalf("Head: %v", err) } if res.StatusCode != 200 { t.Errorf("Head status=%d; want=200", res.StatusCode) } } // If a request has Expect:100-continue header, the request blocks sending body until the first response. // Premature consumption of the request body should not be occurred. func TestTransportExpect100Continue(t *testing.T) { setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { switch req.URL.Path { case "/100": // This endpoint implicitly responds 100 Continue and reads body. if _, err := io.Copy(io.Discard, req.Body); err != nil { t.Error("Failed to read Body", err) } rw.WriteHeader(StatusOK) case "/200": // Go 1.5 adds Connection: close header if the client expect // continue but not entire request body is consumed. rw.WriteHeader(StatusOK) case "/500": rw.WriteHeader(StatusInternalServerError) case "/keepalive": // This hijacked endpoint responds error without Connection:close. 
_, bufrw, err := rw.(Hijacker).Hijack() if err != nil { log.Fatal(err) } bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") bufrw.WriteString("Content-Length: 0\r\n\r\n") bufrw.Flush() case "/timeout": // This endpoint tries to read body without 100 (Continue) response. // After ExpectContinueTimeout, the reading will be started. conn, bufrw, err := rw.(Hijacker).Hijack() if err != nil { log.Fatal(err) } if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { t.Error("Failed to read Body", err) } bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") bufrw.Flush() conn.Close() } })) defer ts.Close() tests := []struct { path string body []byte sent int status int }{ {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. } c := ts.Client() for i, v := range tests { tr := &Transport{ ExpectContinueTimeout: 2 * time.Second, } defer tr.CloseIdleConnections() c.Transport = tr body := bytes.NewReader(v.body) req, err := NewRequest("PUT", ts.URL+v.path, body) if err != nil { t.Fatal(err) } req.Header.Set("Expect", "100-continue") req.ContentLength = int64(len(v.body)) resp, err := c.Do(req) if err != nil { t.Fatal(err) } resp.Body.Close() sent := len(v.body) - body.Len() if v.status != resp.StatusCode { t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) } if v.sent != sent { t.Errorf("test %d: sent body should be %d but sent %d. 
(%s)", i, v.sent, sent, v.path) } } }

// TestSOCKS5Proxy verifies that a Transport configured with a socks5:// proxy
// URL performs a no-auth SOCKS5 CONNECT handshake with the proxy and that both
// plain HTTP and TLS requests are tunneled through it to the target server.
func TestSOCKS5Proxy(t *testing.T) {
	defer afterTest(t)
	ch := make(chan string, 1)
	l := newLocalListener(t)
	defer l.Close()
	defer close(ch)
	// proxy implements just enough of a SOCKS5 server to accept one
	// no-auth CONNECT request on l and then relay bytes to the target.
	proxy := func(t *testing.T) {
		s, err := l.Accept()
		if err != nil {
			t.Errorf("socks5 proxy Accept(): %v", err)
			return
		}
		defer s.Close()
		var buf [22]byte
		if _, err := io.ReadFull(s, buf[:3]); err != nil {
			t.Errorf("socks5 proxy initial read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
			return
		}
		if _, err := s.Write([]byte{5, 0}); err != nil {
			t.Errorf("socks5 proxy initial write: %v", err)
			return
		}
		if _, err := io.ReadFull(s, buf[:4]); err != nil {
			t.Errorf("socks5 proxy second read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
			return
		}
		// buf[3] selects the address length: 1 => IPv4, 4 => IPv6.
		var ipLen int
		switch buf[3] {
		case 1:
			ipLen = net.IPv4len
		case 4:
			ipLen = net.IPv6len
		default:
			// Report the byte actually switched on; buf[4] has not been
			// read from the connection yet.
			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[3])
			return
		}
		// Address bytes followed by a 2-byte big-endian port.
		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy address read: %v", err)
			return
		}
		ip := net.IP(buf[4 : ipLen+4])
		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
		// Reply by echoing the request with the first three bytes replaced.
		copy(buf[:3], []byte{5, 0, 0})
		if _, err := s.Write(buf[:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy connect write: %v", err)
			return
		}
		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)

		// Implement proxying.
		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
		targetConn, err := net.Dial("tcp", targetHost)
		if err != nil {
			t.Errorf("net.Dial(%q) failed: %v", targetHost, err)
			return
		}
		go io.Copy(targetConn, s)
		io.Copy(s, targetConn) // Wait for the client to close the socket.
		targetConn.Close()
	}

	pu, err := url.Parse("socks5://" + l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	sentinelHeader := "X-Sentinel"
	sentinelValue := "12345"
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set(sentinelHeader, sentinelValue)
	})
	for _, useTLS := range []bool{false, true} {
		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
			var ts *httptest.Server
			if useTLS {
				ts = httptest.NewTLSServer(h)
			} else {
				ts = httptest.NewServer(h)
			}
			go proxy(t)
			c := ts.Client()
			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			r, err := c.Head(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			if r.Header.Get(sentinelHeader) != sentinelValue {
				t.Errorf("Failed to retrieve sentinel value")
			}
			var got string
			select {
			case got = <-ch:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout connecting to socks5 proxy")
			}
			ts.Close()
			tsu, err := url.Parse(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			want := "proxy for " + tsu.Host
			if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		})
	}
}

// TestTransportProxy exercises every combination of HTTP/HTTPS site and
// HTTP/HTTPS proxy. HTTPS sites must be reached via a CONNECT tunnel; plain
// HTTP sites must be requested from the proxy using the absolute site URL.
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ httpsSite, httpsProxy bool }{
		{false, false},
		{false, true},
		{true, false},
		{true, true},
	}
	for _, testCase := range testCases {
		httpsSite := testCase.httpsSite
		httpsProxy := testCase.httpsProxy
		t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// Implement an entire CONNECT proxy
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed")
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}

					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						// The hijacked conn is ours now; don't leak it.
						clientConn.Close()
						return
					}

					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						clientConn.Close()
						targetConn.Close()
						return
					}

					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			var ts *httptest.Server
			if httpsSite {
				ts = httptest.NewTLSServer(h1)
			} else {
				ts = httptest.NewServer(h1)
			}
			var proxy *httptest.Server
			if httpsProxy {
				proxy = httptest.NewTLSServer(h2)
			} else {
				proxy = httptest.NewServer(h2)
			}

			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}

			// If neither server is HTTPS or both are, then c may be derived from either.
			// If only one server is HTTPS, c must be derived from that server in order
			// to ensure that it is configured to use the fake root CA from testcert.go.
			c := proxy.Client()
			if httpsSite {
				c = ts.Client()
			}

			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			if _, err := c.Head(ts.URL); err != nil {
				t.Error(err)
			}
			var got *Request
			select {
			case got = <-proxyCh:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout connecting to http proxy")
			}
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if httpsSite {
				// First message should be a CONNECT, asking for a socket to the real server.
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				pu, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatal("Invalid site URL")
				}
				if wantHost := pu.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}

				// The next message on the channel should be from the site's server.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}

// Issue 28012: verify that the Transport closes its TCP connection to http proxies
// when they're slow to reply to HTTPS CONNECT requests.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ln := newLocalListener(t)
	defer ln.Close()
	listenerDone := make(chan struct{})
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()
		// Read the CONNECT request
		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Now hang and never write a response; instead, cancel the request and wait
		// for the client to close.
		// (Prior to Issue 28012 being fixed, we never closed.)
		cancel()
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}

	// Wait unconditionally for the listener goroutine to exit: this should never
	// hang, so if it does we want a full goroutine dump — and that's exactly what
	// the testing package will give us when the test run times out.
	<-listenerDone
}

// Issue 16997: test transport dial preserves typed errors
func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
	defer afterTest(t)

	var errDial = errors.New("some dial error")

	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) {
			return url.Parse("http://proxy.fake.tld/")
		},
		Dial: func(string, string) (net.Conn, error) {
			return nil, errDial
		},
	}
	defer tr.CloseIdleConnections()

	c := &Client{Transport: tr}
	req, _ := NewRequest("GET", "http://fake.tld", nil)
	res, err := c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("wanted a non-nil error")
	}

	uerr, ok := err.(*url.Error)
	if !ok {
		t.Fatalf("got %T, want *url.Error", err)
	}
	oe, ok := uerr.Err.(*net.OpError)
	if !ok {
		t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
	}
	want := &net.OpError{
		Op:  "proxyconnect",
		Net: "tcp",
		Err: errDial, // original error, unwrapped.
	}
	if !reflect.DeepEqual(oe, want) {
		t.Errorf("Got error %#v; want %#v", oe, want)
	}
}

// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader.
//
// (A bug caused dialConn to instead write the per-request Proxy-Authorization
// header through to the shared Header instance, introducing a data race.)
func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { setParallel(t) defer afterTest(t) proxy := httptest.NewTLSServer(NotFoundHandler()) defer proxy.Close() c := proxy.Client() tr := c.Transport.(*Transport) tr.Proxy = func(*Request) (*url.URL, error) { u, _ := url.Parse(proxy.URL) u.User = url.UserPassword("aladdin", "opensesame") return u, nil } h := tr.ProxyConnectHeader if h == nil { h = make(Header) } tr.ProxyConnectHeader = h.Clone() req, err := NewRequest("GET", "https://golang.fake.tld/", nil) if err != nil { t.Fatal(err) } _, err = c.Do(req) if err == nil { t.Errorf("unexpected Get success") } if !reflect.DeepEqual(tr.ProxyConnectHeader, h) { t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h) } } // TestTransportGzipRecursive sends a gzip quine and checks that the // client gets the same value back. This is more cute than anything, // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. func TestTransportGzipRecursive(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) })) defer ts.Close() c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if !bytes.Equal(body, rgz) { t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", body, rgz) } if g, e := res.Header.Get("Content-Encoding"), ""; g != e { t.Fatalf("Content-Encoding = %q; want %q", g, e) } } // golang.org/issue/7750: request fails when server replies with // a short gzip body func TestTransportGzipShort(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write([]byte{0x1f, 0x8b}) })) defer ts.Close() c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() _, err = 
io.ReadAll(res.Body) if err == nil { t.Fatal("Expect an error from reading a body.") } if err != io.ErrUnexpectedEOF { t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) } } // Wait until number of goroutines is no greater than nmax, or time out. func waitNumGoroutine(nmax int) int { nfinal := runtime.NumGoroutine() for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- { time.Sleep(50 * time.Millisecond) runtime.GC() nfinal = runtime.NumGoroutine() } return nfinal } // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { // Not parallel: counts goroutines defer afterTest(t) const numReq = 25 gotReqCh := make(chan bool, numReq) unblockCh := make(chan bool, numReq) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { gotReqCh <- true <-unblockCh w.Header().Set("Content-Length", "0") w.WriteHeader(204) })) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) n0 := runtime.NumGoroutine() didReqCh := make(chan bool, numReq) failed := make(chan bool, numReq) for i := 0; i < numReq; i++ { go func() { res, err := c.Get(ts.URL) didReqCh <- true if err != nil { t.Logf("client fetch error: %v", err) failed <- true return } res.Body.Close() }() } // Wait for all goroutines to be stuck in the Handler. for i := 0; i < numReq; i++ { select { case <-gotReqCh: // ok case <-failed: // Not great but not what we are testing: // sometimes an overloaded system will fail to make all the connections. } } nhigh := runtime.NumGoroutine() // Tell all handlers to unblock and reply. close(unblockCh) // Wait for all HTTP clients to be done. for i := 0; i < numReq; i++ { <-didReqCh } tr.CloseIdleConnections() nfinal := waitNumGoroutine(n0 + 5) growth := nfinal - n0 // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. // Previously we were leaking one per numReq. 
if int(growth) > 5 { t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) t.Error("too many new goroutines") } } // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { // Not parallel: measures goroutines. defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) n0 := runtime.NumGoroutine() body := []byte("Hello") for i := 0; i < 20; i++ { req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) if err != nil { t.Fatal(err) } req.ContentLength = int64(len(body) - 2) // explicitly short _, err = c.Do(req) if err == nil { t.Fatal("Expect an error from writing too long of a body.") } } nhigh := runtime.NumGoroutine() tr.CloseIdleConnections() nfinal := waitNumGoroutine(n0 + 5) growth := nfinal - n0 // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. // Previously we were leaking one per numReq. t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) if int(growth) > 5 { t.Error("too many new goroutines") } } // A countedConn is a net.Conn that decrements an atomic counter when finalized. type countedConn struct { net.Conn } // A countingDialer dials connections and counts the number that remain reachable. 
type countingDialer struct { dialer net.Dialer mu sync.Mutex total, live int64 } func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.dialer.DialContext(ctx, network, address) if err != nil { return nil, err } counted := new(countedConn) counted.Conn = conn d.mu.Lock() defer d.mu.Unlock() d.total++ d.live++ runtime.SetFinalizer(counted, d.decrement) return counted, nil } func (d *countingDialer) decrement(*countedConn) { d.mu.Lock() defer d.mu.Unlock() d.live-- } func (d *countingDialer) Read() (total, live int64) { d.mu.Lock() defer d.mu.Unlock() return d.total, d.live } func TestTransportPersistConnLeakNeverIdle(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Close every connection so that it cannot be kept alive. conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Errorf("Hijack failed unexpectedly: %v", err) return } conn.Close() })) defer ts.Close() var d countingDialer c := ts.Client() c.Transport.(*Transport).DialContext = d.DialContext body := []byte("Hello") for i := 0; ; i++ { total, live := d.Read() if live < total { break } if i >= 1<<12 { t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i) } req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) if err != nil { t.Fatal(err) } _, err = c.Do(req) if err == nil { t.Fatal("expected broken connection") } runtime.GC() } } type countedContext struct { context.Context } type contextCounter struct { mu sync.Mutex live int64 } func (cc *contextCounter) Track(ctx context.Context) context.Context { counted := new(countedContext) counted.Context = ctx cc.mu.Lock() defer cc.mu.Unlock() cc.live++ runtime.SetFinalizer(counted, cc.decrement) return counted } func (cc *contextCounter) decrement(*countedContext) { cc.mu.Lock() defer cc.mu.Unlock() cc.live-- } func (cc *contextCounter) Read() (live int64) { 
cc.mu.Lock() defer cc.mu.Unlock() return cc.live } func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { runtime.Gosched() w.WriteHeader(StatusOK) })) defer ts.Close() c := ts.Client() c.Transport.(*Transport).MaxConnsPerHost = 1 ctx := context.Background() body := []byte("Hello") doPosts := func(cc *contextCounter) { var wg sync.WaitGroup for n := 64; n > 0; n-- { wg.Add(1) go func() { defer wg.Done() ctx := cc.Track(ctx) req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) if err != nil { t.Error(err) } _, err = c.Do(req.WithContext(ctx)) if err != nil { t.Errorf("Do failed with error: %v", err) } }() } wg.Wait() } var initialCC contextCounter doPosts(&initialCC) // flushCC exists only to put pressure on the GC to finalize the initialCC // contexts: the flushCC allocations should eventually displace the initialCC // allocations. var flushCC contextCounter for i := 0; ; i++ { live := initialCC.Read() if live == 0 { break } if i >= 100 { t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i) } doPosts(&flushCC) runtime.GC() } } // This used to crash; https://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { defer afterTest(t) var tr *Transport unblockCh := make(chan bool, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockCh tr.CloseIdleConnections() })) defer ts.Close() c := ts.Client() tr = c.Transport.(*Transport) didreq := make(chan bool) go func() { res, err := c.Get(ts.URL) if err != nil { t.Error(err) } else { res.Body.Close() // returns idle conn } didreq <- true }() unblockCh <- true <-didreq } // Test that the transport doesn't close the TCP connection early, // before the response body has been read. This was a regression // which sadly lacked a triggering test. The large response body made // the old race easier to trigger. 
func TestIssue3644(t *testing.T) { defer afterTest(t) const numFoos = 5000 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") for i := 0; i < numFoos; i++ { w.Write([]byte("foo ")) } })) defer ts.Close() c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() bs, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if len(bs) != numFoos*len("foo ") { t.Errorf("unexpected response length") } } // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. func TestIssue3595(t *testing.T) { setParallel(t) defer afterTest(t) const deniedMsg = "sorry, denied." ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, deniedMsg, StatusUnauthorized) })) defer ts.Close() c := ts.Client() res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) if err != nil { t.Errorf("Post: %v", err) return } got, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("Body ReadAll: %v", err) } if !strings.Contains(string(got), deniedMsg) { t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) } } // From https://golang.org/issue/4454 , // "client fails to handle requests with no body and chunked encoding" func TestChunkedNoContent(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) })) defer ts.Close() c := ts.Client() for _, closeBody := range []bool{true, false} { const n = 4 for i := 1; i <= n; i++ { res, err := c.Get(ts.URL) if err != nil { t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) } else { if closeBody { res.Body.Close() } } } } } func TestTransportConcurrency(t *testing.T) { // Not parallel: uses global test hooks. 
defer afterTest(t) maxProcs, numReqs := 16, 500 if testing.Short() { maxProcs, numReqs = 4, 50 } defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.FormValue("echo")) })) defer ts.Close() var wg sync.WaitGroup wg.Add(numReqs) // Due to the Transport's "socket late binding" (see // idleConnCh in transport.go), the numReqs HTTP requests // below can finish with a dial still outstanding. To keep // the leak checker happy, keep track of pending dials and // wait for them to finish (and be closed or returned to the // idle pool) before we close idle connections. SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) defer SetPendingDialHooks(nil, nil) c := ts.Client() reqs := make(chan string) defer close(reqs) for i := 0; i < maxProcs*2; i++ { go func() { for req := range reqs { res, err := c.Get(ts.URL + "/?echo=" + req) if err != nil { t.Errorf("error on req %s: %v", req, err) wg.Done() continue } all, err := io.ReadAll(res.Body) if err != nil { t.Errorf("read error on req %s: %v", req, err) wg.Done() continue } if string(all) != req { t.Errorf("body of req %s = %q; want %q", req, all, req) } res.Body.Close() wg.Done() } }() } for i := 0; i < numReqs; i++ { reqs <- fmt.Sprintf("request-%d", i) } wg.Wait() } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { setParallel(t) defer afterTest(t) const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) ts := httptest.NewServer(mux) defer ts.Close() timeout := 100 * time.Millisecond c := ts.Client() c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { conn, err := net.Dial(n, addr) if err != nil { return nil, err } conn.SetDeadline(time.Now().Add(timeout)) if debug { conn = NewLoggingConn("client", conn) } return conn, nil } getFailed := false nRuns := 5 if testing.Short() { nRuns = 1 } for i := 0; i < nRuns; i++ { if debug { 
println("run", i+1, "of", nRuns) } sres, err := c.Get(ts.URL + "/get") if err != nil { if !getFailed { // Make the timeout longer, once. getFailed = true t.Logf("increasing timeout") i-- timeout *= 10 continue } t.Errorf("Error issuing GET: %v", err) break } _, err = io.Copy(io.Discard, sres.Body) if err == nil { t.Errorf("Unexpected successful copy") break } } if debug { println("tests complete; waiting for handlers to finish") } } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { setParallel(t) defer afterTest(t) const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { defer r.Body.Close() io.Copy(io.Discard, r.Body) }) ts := httptest.NewServer(mux) timeout := 100 * time.Millisecond c := ts.Client() c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { conn, err := net.Dial(n, addr) if err != nil { return nil, err } conn.SetDeadline(time.Now().Add(timeout)) if debug { conn = NewLoggingConn("client", conn) } return conn, nil } getFailed := false nRuns := 5 if testing.Short() { nRuns = 1 } for i := 0; i < nRuns; i++ { if debug { println("run", i+1, "of", nRuns) } sres, err := c.Get(ts.URL + "/get") if err != nil { if !getFailed { // Make the timeout longer, once. 
getFailed = true t.Logf("increasing timeout") i-- timeout *= 10 continue } t.Errorf("Error issuing GET: %v", err) break } req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) _, err = c.Do(req) if err == nil { sres.Body.Close() t.Errorf("Unexpected successful PUT") break } sres.Body.Close() } if debug { println("tests complete; waiting for handlers to finish") } ts.Close() } func TestTransportResponseHeaderTimeout(t *testing.T) { setParallel(t) defer afterTest(t) if testing.Short() { t.Skip("skipping timeout test in -short mode") } inHandler := make(chan bool, 1) mux := NewServeMux() mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { inHandler <- true }) mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { inHandler <- true time.Sleep(2 * time.Second) }) ts := httptest.NewServer(mux) defer ts.Close() c := ts.Client() c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond tests := []struct { path string want int wantErr string }{ {path: "/fast", want: 200}, {path: "/slow", wantErr: "timeout awaiting response headers"}, {path: "/fast", want: 200}, } for i, tt := range tests { req, _ := NewRequest("GET", ts.URL+tt.path, nil) req = req.WithT(t) res, err := c.Do(req) select { case <-inHandler: case <-time.After(5 * time.Second): t.Errorf("never entered handler for test index %d, %s", i, tt.path) continue } if err != nil { uerr, ok := err.(*url.Error) if !ok { t.Errorf("error is not an url.Error; got: %#v", err) continue } nerr, ok := uerr.Err.(net.Error) if !ok { t.Errorf("error does not satisfy net.Error interface; got: %#v", err) continue } if !nerr.Timeout() { t.Errorf("want timeout error; got: %q", nerr) continue } if strings.Contains(err.Error(), tt.wantErr) { continue } t.Errorf("%d. unexpected error: %v", i, err) continue } if tt.wantErr != "" { t.Errorf("%d. no error. 
expected error: %v", i, tt.wantErr) continue } if res.StatusCode != tt.want { t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) } } } func TestTransportCancelRequest(t *testing.T) { setParallel(t) defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello") w.(Flusher).Flush() // send headers and some body <-unblockc })) defer ts.Close() defer close(unblockc) c := ts.Client() tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) res, err := c.Do(req) if err != nil { t.Fatal(err) } go func() { time.Sleep(1 * time.Second) tr.CancelRequest(req) }() t0 := time.Now() body, err := io.ReadAll(res.Body) d := time.Since(t0) if err != ExportErrRequestCanceled { t.Errorf("Body.Read error = %v; want errRequestCanceled", err) } if string(body) != "Hello" { t.Errorf("Body = %q; want Hello", body) } if d < 500*time.Millisecond { t.Errorf("expected ~1 second delay; got %v", d) } // Verify no outstanding requests after readLoop/writeLoop // goroutines shut down. 
for tries := 5; tries > 0; tries-- { n := tr.NumPendingRequestsForTesting() if n == 0 { break } time.Sleep(100 * time.Millisecond) if tries == 1 { t.Errorf("pending requests = %d; want 0", n) } } } func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { setParallel(t) defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc })) defer ts.Close() defer close(unblockc) c := ts.Client() tr := c.Transport.(*Transport) donec := make(chan bool) req, _ := NewRequest("GET", ts.URL, body) go func() { defer close(donec) c.Do(req) }() start := time.Now() timeout := 10 * time.Second for time.Since(start) < timeout { time.Sleep(100 * time.Millisecond) tr.CancelRequest(req) select { case <-donec: return default: } } t.Errorf("Do of canceled request has not returned after %v", timeout) } func TestTransportCancelRequestInDo(t *testing.T) { testTransportCancelRequestInDo(t, nil) } func TestTransportCancelRequestWithBodyInDo(t *testing.T) { testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0})) } func TestTransportCancelRequestInDial(t *testing.T) { defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") } var logbuf bytes.Buffer eventLog := log.New(&logbuf, "", 0) unblockDial := make(chan bool) defer close(unblockDial) inDial := make(chan bool) tr := &Transport{ Dial: func(network, addr string) (net.Conn, error) { eventLog.Println("dial: blocking") if !<-inDial { return nil, errors.New("main Test goroutine exited") } <-unblockDial return nil, errors.New("nope") }, } cl := &Client{Transport: tr} gotres := make(chan bool) req, _ := NewRequest("GET", "http://something.no-network.tld/", nil) go func() { _, err := cl.Do(req) eventLog.Printf("Get = %v", err) gotres <- true }() select { case inDial <- true: case <-time.After(5 * time.Second): close(inDial) t.Fatal("timeout; never saw blocking dial") } 
eventLog.Printf("canceling") tr.CancelRequest(req) tr.CancelRequest(req) // used to panic on second call select { case <-gotres: case <-time.After(5 * time.Second): panic("hang. events are: " + logbuf.String()) } got := logbuf.String() want := `dial: blocking canceling Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection ` if got != want { t.Errorf("Got events:\n%s\nWant:\n%s", got, want) } } func TestCancelRequestWithChannel(t *testing.T) { setParallel(t) defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello") w.(Flusher).Flush() // send headers and some body <-unblockc })) defer ts.Close() defer close(unblockc) c := ts.Client() tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) ch := make(chan struct{}) req.Cancel = ch res, err := c.Do(req) if err != nil { t.Fatal(err) } go func() { time.Sleep(1 * time.Second) close(ch) }() t0 := time.Now() body, err := io.ReadAll(res.Body) d := time.Since(t0) if err != ExportErrRequestCanceled { t.Errorf("Body.Read error = %v; want errRequestCanceled", err) } if string(body) != "Hello" { t.Errorf("Body = %q; want Hello", body) } if d < 500*time.Millisecond { t.Errorf("expected ~1 second delay; got %v", d) } // Verify no outstanding requests after readLoop/writeLoop // goroutines shut down. 
for tries := 5; tries > 0; tries-- { n := tr.NumPendingRequestsForTesting() if n == 0 { break } time.Sleep(100 * time.Millisecond) if tries == 1 { t.Errorf("pending requests = %d; want 0", n) } } } func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { testCancelRequestWithChannelBeforeDo(t, false) } func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { testCancelRequestWithChannelBeforeDo(t, true) } func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { setParallel(t) defer afterTest(t) unblockc := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc })) defer ts.Close() defer close(unblockc) c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) if withCtx { ctx, cancel := context.WithCancel(context.Background()) cancel() req = req.WithContext(ctx) } else { ch := make(chan struct{}) req.Cancel = ch close(ch) } _, err := c.Do(req) if ue, ok := err.(*url.Error); ok { err = ue.Err } if withCtx { if err != context.Canceled { t.Errorf("Do error = %v; want %v", err, context.Canceled) } } else { if err == nil || !strings.Contains(err.Error(), "canceled") { t.Errorf("Do error = %v; want cancellation", err) } } } // Issue 11020. 
The returned error message should be errRequestCanceled func TestTransportCancelBeforeResponseHeaders(t *testing.T) { defer afterTest(t) serverConnCh := make(chan net.Conn, 1) tr := &Transport{ Dial: func(network, addr string) (net.Conn, error) { cc, sc := net.Pipe() serverConnCh <- sc return cc, nil }, } defer tr.CloseIdleConnections() errc := make(chan error, 1) req, _ := NewRequest("GET", "http://example.com/", nil) go func() { _, err := tr.RoundTrip(req) errc <- err }() sc := <-serverConnCh verb := make([]byte, 3) if _, err := io.ReadFull(sc, verb); err != nil { t.Errorf("Error reading HTTP verb from server: %v", err) } if string(verb) != "GET" { t.Errorf("server received %q; want GET", verb) } defer sc.Close() tr.CancelRequest(req) err := <-errc if err == nil { t.Fatalf("unexpected success from RoundTrip") } if err != ExportErrRequestCanceled { t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err) } } // golang.org/issue/3672 -- Client can't close HTTP stream // Calling Close on a Response.Body used to just read until EOF. // Now it actually closes the TCP connection. 
func TestTransportCloseResponseBody(t *testing.T) { defer afterTest(t) writeErr := make(chan error, 1) msg := []byte("young\n") ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { for { _, err := w.Write(msg) if err != nil { writeErr <- err return } w.(Flusher).Flush() } })) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) defer tr.CancelRequest(req) res, err := c.Do(req) if err != nil { t.Fatal(err) } const repeats = 3 buf := make([]byte, len(msg)*repeats) want := bytes.Repeat(msg, repeats) _, err = io.ReadFull(res.Body, buf) if err != nil { t.Fatal(err) } if !bytes.Equal(buf, want) { t.Fatalf("read %q; want %q", buf, want) } didClose := make(chan error, 1) go func() { didClose <- res.Body.Close() }() select { case err := <-didClose: if err != nil { t.Errorf("Close = %v", err) } case <-time.After(10 * time.Second): t.Fatal("too long waiting for close") } select { case err := <-writeErr: if err == nil { t.Errorf("expected non-nil write error") } case <-time.After(10 * time.Second): t.Fatal("too long waiting for write error") } } type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { res := &Response{ Status: "200 OK", StatusCode: 200, Header: make(Header), Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())), } return res, nil } func TestTransportAltProto(t *testing.T) { defer afterTest(t) tr := &Transport{} c := &Client{Transport: tr} tr.RegisterProtocol("foo", fooProto{}) res, err := c.Get("foo://bar.com/path") if err != nil { t.Fatal(err) } bodyb, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } body := string(bodyb) if e := "You wanted foo://bar.com/path"; body != e { t.Errorf("got response %q, want %q", body, e) } } func TestTransportNoHost(t *testing.T) { defer afterTest(t) tr := &Transport{} _, err := tr.RoundTrip(&Request{ Header: make(Header), URL: &url.URL{ Scheme: "http", }, }) want := "http: no Host in request URL" if 
got := fmt.Sprint(err); got != want { t.Errorf("error = %v; want %q", err, want) } } // Issue 13311 func TestTransportEmptyMethod(t *testing.T) { req, _ := NewRequest("GET", "http://foo.com/", nil) req.Method = "" // docs say "For client requests an empty string means GET" got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport if err != nil { t.Fatal(err) } if !strings.Contains(string(got), "GET ") { t.Fatalf("expected substring 'GET '; got: %s", got) } } func TestTransportSocketLateBinding(t *testing.T) { setParallel(t) defer afterTest(t) mux := NewServeMux() fooGate := make(chan bool, 1) mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { w.Header().Set("foo-ipport", r.RemoteAddr) w.(Flusher).Flush() <-fooGate }) mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { w.Header().Set("bar-ipport", r.RemoteAddr) }) ts := httptest.NewServer(mux) defer ts.Close() dialGate := make(chan bool, 1) c := ts.Client() c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { if <-dialGate { return net.Dial(n, addr) } return nil, errors.New("manually closed") } dialGate <- true // only allow one dial fooRes, err := c.Get(ts.URL + "/foo") if err != nil { t.Fatal(err) } fooAddr := fooRes.Header.Get("foo-ipport") if fooAddr == "" { t.Fatal("No addr on /foo request") } time.AfterFunc(200*time.Millisecond, func() { // let the foo response finish so we can use its // connection for /bar fooGate <- true io.Copy(io.Discard, fooRes.Body) fooRes.Body.Close() }) barRes, err := c.Get(ts.URL + "/bar") if err != nil { t.Fatal(err) } barAddr := barRes.Header.Get("bar-ipport") if barAddr != fooAddr { t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) } barRes.Body.Close() dialGate <- false } // Issue 2184 func TestTransportReading100Continue(t *testing.T) { defer afterTest(t) const numReqs = 5 reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } reqID := func(n int) string { return 
fmt.Sprintf("REQ-ID-%d", n) } send100Response := func(w *io.PipeWriter, r *io.PipeReader) { defer w.Close() defer r.Close() br := bufio.NewReader(r) n := 0 for { n++ req, err := ReadRequest(br) if err == io.EOF { return } if err != nil { t.Error(err) return } slurp, err := io.ReadAll(req.Body) if err != nil { t.Errorf("Server request body slurp: %v", err) return } id := req.Header.Get("Request-Id") resCode := req.Header.Get("X-Want-Response-Code") if resCode == "" { resCode = "100 Continue" if string(slurp) != reqBody(n) { t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) } } body := fmt.Sprintf("Response number %d", n) v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s Date: Thu, 28 Feb 2013 17:55:41 GMT HTTP/1.1 200 OK Content-Type: text/html Echo-Request-Id: %s Content-Length: %d %s`, resCode, id, len(body), body), "\n", "\r\n", -1)) w.Write(v) if id == reqID(numReqs) { return } } } tr := &Transport{ Dial: func(n, addr string) (net.Conn, error) { sr, sw := io.Pipe() // server read/write cr, cw := io.Pipe() // client read/write conn := &rwTestConn{ Reader: cr, Writer: sw, closeFunc: func() error { sw.Close() cw.Close() return nil }, } go send100Response(cw, sr) return conn, nil }, DisableKeepAlives: false, } defer tr.CloseIdleConnections() c := &Client{Transport: tr} testResponse := func(req *Request, name string, wantCode int) { t.Helper() res, err := c.Do(req) if err != nil { t.Fatalf("%s: Do: %v", name, err) } if res.StatusCode != wantCode { t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) } if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { t.Errorf("%s: response id %q != request id %q", name, idBack, id) } _, err = io.ReadAll(res.Body) if err != nil { t.Fatalf("%s: Slurp error: %v", name, err) } } // Few 100 responses, making sure we're not off-by-one. 
for i := 1; i <= numReqs; i++ { req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) req.Header.Set("Request-Id", reqID(i)) testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) } } // Issue 17739: the HTTP client must ignore any unknown 1xx // informational responses before the actual response. func TestTransportIgnore1xxResponses(t *testing.T) { setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) buf.Flush() conn.Close() })) defer cst.close() cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway var got bytes.Buffer req, _ := NewRequest("GET", cst.ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ Got1xxResponse: func(code int, header textproto.MIMEHeader) error { fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) return nil }, })) res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() res.Write(&got) want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" if got.String() != want { t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want) } } func TestTransportLimits1xxResponses(t *testing.T) { setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() for i := 0; i < 10; i++ { buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) } buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) buf.Flush() conn.Close() })) defer cst.close() cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway res, err := cst.c.Get(cst.ts.URL) if res != nil { defer res.Body.Close() } got := fmt.Sprint(err) wantSub := 
"too many 1xx informational responses" if !strings.Contains(got, wantSub) { t.Errorf("Get error = %v; want substring %q", err, wantSub) } } // Issue 26161: the HTTP client must treat 101 responses // as the final response. func TestTransportTreat101Terminal(t *testing.T) { setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) buf.Flush() conn.Close() })) defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() if res.StatusCode != StatusSwitchingProtocols { t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode) } } type proxyFromEnvTest struct { req string // URL to fetch; blank means "http://example.com" env string // HTTP_PROXY httpsenv string // HTTPS_PROXY noenv string // NO_PROXY reqmeth string // REQUEST_METHOD want string wanterr error } func (t proxyFromEnvTest) String() string { var buf bytes.Buffer space := func() { if buf.Len() > 0 { buf.WriteByte(' ') } } if t.env != "" { fmt.Fprintf(&buf, "http_proxy=%q", t.env) } if t.httpsenv != "" { space() fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv) } if t.noenv != "" { space() fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) } if t.reqmeth != "" { space() fmt.Fprintf(&buf, "request_method=%q", t.reqmeth) } req := "http://example.com" if t.req != "" { req = t.req } space() fmt.Fprintf(&buf, "req=%q", req) return strings.TrimSpace(buf.String()) } var proxyFromEnvTests = []proxyFromEnvTest{ {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, 
{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"}, // Don't use secure for http {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, // Use secure for https. {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, // Issue 16405: don't use HTTP_PROXY in a CGI environment, // where HTTP_PROXY can be attacker-controlled. {env: "http://10.1.2.3:8080", reqmeth: "POST", want: "", wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")}, {want: ""}, {noenv: "example.com", req: "http://example.com/", env: "proxy", want: ""}, {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: ""}, {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, } func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) { t.Helper() reqURL := tt.req if reqURL == "" { reqURL = "http://example.com" } req, _ := NewRequest("GET", reqURL, nil) url, err := proxyForRequest(req) if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { t.Errorf("%v: got error = %q, want %q", tt, g, e) return } if got := fmt.Sprintf("%s", url); got != tt.want { t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) } } func TestProxyFromEnvironment(t *testing.T) { ResetProxyEnv() defer ResetProxyEnv() for _, tt := range proxyFromEnvTests { testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { os.Setenv("HTTP_PROXY", tt.env) os.Setenv("HTTPS_PROXY", 
tt.httpsenv) os.Setenv("NO_PROXY", tt.noenv) os.Setenv("REQUEST_METHOD", tt.reqmeth) ResetCachedEnvironment() return ProxyFromEnvironment(req) }) } } func TestProxyFromEnvironmentLowerCase(t *testing.T) { ResetProxyEnv() defer ResetProxyEnv() for _, tt := range proxyFromEnvTests { testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { os.Setenv("http_proxy", tt.env) os.Setenv("https_proxy", tt.httpsenv) os.Setenv("no_proxy", tt.noenv) os.Setenv("REQUEST_METHOD", tt.reqmeth) ResetCachedEnvironment() return ProxyFromEnvironment(req) }) } } func TestIdleConnChannelLeak(t *testing.T) { // Not parallel: uses global test hooks. var mu sync.Mutex var n int ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() n++ mu.Unlock() })) defer ts.Close() const nReqs = 5 didRead := make(chan bool, nReqs) SetReadLoopBeforeNextReadHook(func() { didRead <- true }) defer SetReadLoopBeforeNextReadHook(nil) c := ts.Client() tr := c.Transport.(*Transport) tr.Dial = func(netw, addr string) (net.Conn, error) { return net.Dial(netw, ts.Listener.Addr().String()) } // First, without keep-alives. for _, disableKeep := range []bool{true, false} { tr.DisableKeepAlives = disableKeep for i := 0; i < nReqs; i++ { _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) if err != nil { t.Fatal(err) } // Note: no res.Body.Close is needed here, since the // response Content-Length is zero. Perhaps the test // should be more explicit and use a HEAD, but tests // elsewhere guarantee that zero byte responses generate // a "Content-Length: 0" instead of chunking. } // At this point, each of the 5 Transport.readLoop goroutines // are scheduling noting that there are no response bodies (see // earlier comment), and are then calling putIdleConn, which // decrements this count. Usually that happens quickly, which is // why this test has seemed to work for ages. But it's still // racey: we have wait for them to finish first. 
See Issue 10427 for i := 0; i < nReqs; i++ { <-didRead } if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 { t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) } } } // Verify the status quo: that the Client.Post function coerces its // body into a ReadCloser if it's a Closer, and that the Transport // then closes it. func TestTransportClosesRequestBody(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) })) defer ts.Close() c := ts.Client() closes := 0 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) if err != nil { t.Fatal(err) } res.Body.Close() if closes != 1 { t.Errorf("closes = %d; want 1", closes) } } func TestTransportTLSHandshakeTimeout(t *testing.T) { defer afterTest(t) if testing.Short() { t.Skip("skipping in short mode") } ln := newLocalListener(t) defer ln.Close() testdonec := make(chan struct{}) defer close(testdonec) go func() { c, err := ln.Accept() if err != nil { t.Error(err) return } <-testdonec c.Close() }() getdonec := make(chan struct{}) go func() { defer close(getdonec) tr := &Transport{ Dial: func(_, _ string) (net.Conn, error) { return net.Dial("tcp", ln.Addr().String()) }, TLSHandshakeTimeout: 250 * time.Millisecond, } cl := &Client{Transport: tr} _, err := cl.Get("https://dummy.tld/") if err == nil { t.Error("expected error") return } ue, ok := err.(*url.Error) if !ok { t.Errorf("expected url.Error; got %#v", err) return } ne, ok := ue.Err.(net.Error) if !ok { t.Errorf("expected net.Error; got %#v", err) return } if !ne.Timeout() { t.Errorf("expected timeout error; got %v", err) } if !strings.Contains(err.Error(), "handshake timeout") { t.Errorf("expected 'handshake timeout' in error; got %v", err) } }() select { case <-getdonec: case <-time.After(5 * time.Second): t.Error("test timeout; TLS handshake hung?") } } // Trying to repro golang.org/issue/3514 func 
TestTLSServerClosesConnection(t *testing.T) { defer afterTest(t) closedc := make(chan bool, 1) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.Path, "/keep-alive-then-die") { conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) conn.Close() closedc <- true return } fmt.Fprintf(w, "hello") })) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) var nSuccess = 0 var errs []error const trials = 20 for i := 0; i < trials; i++ { tr.CloseIdleConnections() res, err := c.Get(ts.URL + "/keep-alive-then-die") if err != nil { t.Fatal(err) } <-closedc slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if string(slurp) != "foo" { t.Errorf("Got %q, want foo", slurp) } // Now try again and see if we successfully // pick a new connection. res, err = c.Get(ts.URL + "/") if err != nil { errs = append(errs, err) continue } slurp, err = io.ReadAll(res.Body) if err != nil { errs = append(errs, err) continue } nSuccess++ } if nSuccess > 0 { t.Logf("successes = %d of %d", nSuccess, trials) } else { t.Errorf("All runs failed:") } for _, err := range errs { t.Logf(" err: %v", err) } } // byteFromChanReader is an io.Reader that reads a single byte at a // time from the channel. When the channel is closed, the reader // returns io.EOF. type byteFromChanReader chan byte func (c byteFromChanReader) Read(p []byte) (n int, err error) { if len(p) == 0 { return } b, ok := <-c if !ok { return 0, io.EOF } p[0] = b return 1, nil } // Verifies that the Transport doesn't reuse a connection in the case // where the server replies before the request has been fully // written. We still honor that reply (see TestIssue3595), but don't // send future requests on the connection because it's then in a // questionable state. 
// golang.org/issue/7569 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { setParallel(t) defer afterTest(t) var sconn struct { sync.Mutex c net.Conn } var getOkay bool closeConn := func() { sconn.Lock() defer sconn.Unlock() if sconn.c != nil { sconn.c.Close() sconn.c = nil if !getOkay { t.Logf("Closed server connection") } } } defer closeConn() ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { io.WriteString(w, "bar") return } conn, _, _ := w.(Hijacker).Hijack() sconn.Lock() sconn.c = conn sconn.Unlock() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive go io.Copy(io.Discard, conn) })) defer ts.Close() c := ts.Client() const bodySize = 256 << 10 finalBit := make(byteFromChanReader, 1) req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) req.ContentLength = bodySize res, err := c.Do(req) if err := wantBody(res, err, "foo"); err != nil { t.Errorf("POST response: %v", err) } donec := make(chan bool) go func() { defer close(donec) res, err = c.Get(ts.URL) if err := wantBody(res, err, "bar"); err != nil { t.Errorf("GET response: %v", err) return } getOkay = true // suppress test noise }() time.AfterFunc(5*time.Second, closeConn) select { case <-donec: finalBit <- 'x' // unblock the writeloop of the first Post close(finalBit) case <-time.After(7 * time.Second): t.Fatal("timeout waiting for GET request to finish") } } // Tests that we don't leak Transport persistConn.readLoop goroutines // when a server hangs up immediately after saying it would keep-alive. func TestTransportIssue10457(t *testing.T) { defer afterTest(t) // used to fail in goroutine leak check ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Send a response with no body, keep-alive // (implicit), and then lie and immediately close the // connection. 
This forces the Transport's readLoop to // immediately Peek an io.EOF and get to the point // that used to hang. conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive conn.Close() })) defer ts.Close() c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get: %v", err) } defer res.Body.Close() // Just a sanity check that we at least get the response. The real // test here is that the "defer afterTest" above doesn't find any // leaked goroutines. if got, want := res.Header.Get("Foo"), "Bar"; got != want { t.Errorf("Foo header = %q; want %q", got, want) } } type closerFunc func() error func (f closerFunc) Close() error { return f() } type writerFuncConn struct { net.Conn write func(p []byte) (n int, err error) } func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } // Issues 4677, 18241, and 17844. If we try to reuse a connection that the // server is in the process of closing, we may end up successfully writing out // our request (or a portion of our request) only to find a connection error // when we try to read from (or finish writing to) the socket. // // NOTE: we resend a request only if: // - we reused a keep-alive connection // - we haven't yet received any header data // - either we wrote no bytes to the server, or the request is idempotent // This automatically prevents an infinite resend loop because we'll run out of // the cached keep-alive connections eventually. func TestRetryRequestsOnError(t *testing.T) { newRequest := func(method, urlStr string, body io.Reader) *Request { req, err := NewRequest(method, urlStr, body) if err != nil { t.Fatal(err) } return req } testCases := []struct { name string failureN int failureErr error // Note that we can't just re-use the Request object across calls to c.Do // because we need to rewind Body between calls. 
(GetBody is only used to // rewind Body on failure and redirects, not just because it's done.) req func() *Request reqString string }{ { name: "IdempotentNoBodySomeWritten", // Believe that we've written some bytes to the server, so we know we're // not just in the "retry when no bytes sent" case". failureN: 1, // Use the specific error that shouldRetryRequest looks for with idempotent requests. failureErr: ExportErrServerClosedIdle, req: func() *Request { return newRequest("GET", "http://fake.golang", nil) }, reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, }, { name: "IdempotentGetBodySomeWritten", // Believe that we've written some bytes to the server, so we know we're // not just in the "retry when no bytes sent" case". failureN: 1, // Use the specific error that shouldRetryRequest looks for with idempotent requests. failureErr: ExportErrServerClosedIdle, req: func() *Request { return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n")) }, reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, }, { name: "NothingWrittenNoBody", // It's key that we return 0 here -- that's what enables Transport to know // that nothing was written, even though this is a non-idempotent request. failureN: 0, failureErr: errors.New("second write fails"), req: func() *Request { return newRequest("DELETE", "http://fake.golang", nil) }, reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, }, { name: "NothingWrittenGetBody", // It's key that we return 0 here -- that's what enables Transport to know // that nothing was written, even though this is a non-idempotent request. 
failureN: 0, failureErr: errors.New("second write fails"), // Note that NewRequest will set up GetBody for strings.Reader, which is // required for the retry to occur req: func() *Request { return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n")) }, reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { defer afterTest(t) var ( mu sync.Mutex logbuf bytes.Buffer ) logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&logbuf, format, args...) logbuf.WriteByte('\n') } ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { logf("Handler") w.Header().Set("X-Status", "ok") })) defer ts.Close() var writeNumAtomic int32 c := ts.Client() c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) { logf("Dial") c, err := net.Dial(network, ts.Listener.Addr().String()) if err != nil { logf("Dial error: %v", err) return nil, err } return &writerFuncConn{ Conn: c, write: func(p []byte) (n int, err error) { if atomic.AddInt32(&writeNumAtomic, 1) == 2 { logf("intentional write failure") return tc.failureN, tc.failureErr } logf("Write(%q)", p) return c.Write(p) }, }, nil } SetRoundTripRetried(func() { logf("Retried.") }) defer SetRoundTripRetried(nil) for i := 0; i < 3; i++ { t0 := time.Now() req := tc.req() res, err := c.Do(req) if err != nil { if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 { mu.Lock() got := logbuf.String() mu.Unlock() t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got) } t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse) } res.Body.Close() if res.Request != req { t.Errorf("Response.Request != original request; want identical Request") } } mu.Lock() got := logbuf.String() mu.Unlock() want := fmt.Sprintf(`Dial Write("%s") Handler intentional write 
failure Retried. Dial Write("%s") Handler Write("%s") Handler `, tc.reqString, tc.reqString, tc.reqString) if got != want { t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want) } }) } } // Issue 6981 func TestTransportClosesBodyOnError(t *testing.T) { setParallel(t) defer afterTest(t) readBody := make(chan error, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.ReadAll(r.Body) readBody <- err })) defer ts.Close() c := ts.Client() fakeErr := errors.New("fake error") didClose := make(chan bool, 1) req, _ := NewRequest("POST", ts.URL, struct { io.Reader io.Closer }{ io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), closerFunc(func() error { select { case didClose <- true: default: } return nil }), }) res, err := c.Do(req) if res != nil { defer res.Body.Close() } if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) } select { case err := <-readBody: if err == nil { t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") } case <-time.After(5 * time.Second): t.Error("timeout waiting for server handler to complete") } select { case <-didClose: default: t.Errorf("didn't see Body.Close") } } func TestTransportDialTLS(t *testing.T) { setParallel(t) defer afterTest(t) var mu sync.Mutex // guards following var gotReq, didDial bool ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() })) defer ts.Close() c := ts.Client() c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { mu.Lock() didDial = true mu.Unlock() c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) if err != nil { return nil, err } return c, c.Handshake() } res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } res.Body.Close() mu.Lock() if !gotReq { t.Error("didn't get request") } if 
!didDial { t.Error("didn't use dial hook") } } func TestTransportDialContext(t *testing.T) { setParallel(t) defer afterTest(t) var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() })) defer ts.Close() c := ts.Client() c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() receivedContext = ctx mu.Unlock() return net.Dial(netw, addr) } req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } ctx := context.WithValue(context.Background(), "some-key", "some-value") res, err := c.Do(req.WithContext(ctx)) if err != nil { t.Fatal(err) } res.Body.Close() mu.Lock() if !gotReq { t.Error("didn't get request") } if receivedContext != ctx { t.Error("didn't receive correct context") } } func TestTransportDialTLSContext(t *testing.T) { setParallel(t) defer afterTest(t) var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() })) defer ts.Close() c := ts.Client() c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() receivedContext = ctx mu.Unlock() c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) if err != nil { return nil, err } return c, c.HandshakeContext(ctx) } req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } ctx := context.WithValue(context.Background(), "some-key", "some-value") res, err := c.Do(req.WithContext(ctx)) if err != nil { t.Fatal(err) } res.Body.Close() mu.Lock() if !gotReq { t.Error("didn't get request") } if receivedContext != ctx { t.Error("didn't receive correct context") } } // Test for issue 8755 // Ensure that if a proxy returns an error, it is exposed by RoundTrip func 
TestRoundTripReturnsProxyError(t *testing.T) { badProxy := func(*Request) (*url.URL, error) { return nil, errors.New("errorMessage") } tr := &Transport{Proxy: badProxy} req, _ := NewRequest("GET", "http://example.com", nil) _, err := tr.RoundTrip(req) if err == nil { t.Error("Expected proxy error to be returned by RoundTrip") } } // tests that putting an idle conn after a call to CloseIdleConns does return it func TestTransportCloseIdleConnsThenReturn(t *testing.T) { tr := &Transport{} wantIdle := func(when string, n int) bool { got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn if got == n { return true } t.Errorf("%s: idle conns = %d; want %d", when, got, n) return false } wantIdle("start", 0) if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("put failed") } if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("second put failed") } wantIdle("after put", 2) tr.CloseIdleConnections() if !tr.IsIdleForTesting() { t.Error("should be idle after CloseIdleConnections") } wantIdle("after close idle", 0) if tr.PutIdleTestConn("http", "example.com") { t.Fatal("put didn't fail") } wantIdle("after second put", 0) tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode if tr.IsIdleForTesting() { t.Error("shouldn't be idle after QueueForIdleConnForTesting") } if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("after re-activation") } wantIdle("after final put", 1) } // Test for issue 34282 // Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn func TestTransportTraceGotConnH2IdleConns(t *testing.T) { tr := &Transport{} wantIdle := func(when string, n int) bool { got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 if got == n { return true } t.Errorf("%s: idle conns = %d; want %d", when, got, n) return false } wantIdle("start", 0) alt := funcRoundTripper(func() {}) if !tr.PutIdleTestConnH2("https", "example.com:443", alt) { t.Fatal("put 
failed") } wantIdle("after put", 1) ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ GotConn: func(httptrace.GotConnInfo) { // tr.getConn should leave it for the HTTP/2 alt to call GotConn. t.Error("GotConn called") }, }) req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil) _, err := tr.RoundTrip(req) if err != errFakeRoundTrip { t.Errorf("got error: %v; want %q", err, errFakeRoundTrip) } wantIdle("after round trip", 1) } func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } trFunc := func(tr *Transport) { tr.MaxConnsPerHost = 1 tr.MaxIdleConnsPerHost = 1 tr.IdleConnTimeout = 10 * time.Millisecond } cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) defer cst.close() if _, err := cst.c.Get(cst.ts.URL); err != nil { t.Fatalf("got error: %s", err) } time.Sleep(100 * time.Millisecond) got := make(chan error) go func() { if _, err := cst.c.Get(cst.ts.URL); err != nil { got <- err } close(got) }() timeout := time.NewTimer(5 * time.Second) defer timeout.Stop() select { case err := <-got: if err != nil { t.Fatalf("got error: %s", err) } case <-timeout.C: t.Fatal("request never completed") } } // This tests that a client requesting a content range won't also // implicitly ask for gzip support. If they want that, they need to do it // on their own. 
// golang.org/issue/8923 func TestTransportRangeAndGzip(t *testing.T) { defer afterTest(t) reqc := make(chan *Request, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { reqc <- r })) defer ts.Close() c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) req.Header.Set("Range", "bytes=7-11") res, err := c.Do(req) if err != nil { t.Fatal(err) } select { case r := <-reqc: if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { t.Error("Transport advertised gzip support in the Accept header") } if r.Header.Get("Range") == "" { t.Error("no Range in request") } case <-time.After(10 * time.Second): t.Fatal("timeout") } res.Body.Close() } // Test for issue 10474 func TestTransportResponseCancelRace(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // important that this response has a body. var b [1024]byte w.Write(b[:]) })) defer ts.Close() tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } // If we do an early close, Transport just throws the connection away and // doesn't reuse it. In order to trigger the bug, it has to reuse the connection // so read the body if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatal(err) } req2, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } tr.CancelRequest(req) res, err = tr.RoundTrip(req2) if err != nil { t.Fatal(err) } res.Body.Close() } // Test for issue 19248: Content-Encoding's value is case insensitive. 
// TestTransportContentEncodingCaseInsensitive serves a gzip-compressed
// "Hello Gopher" body labelled with Content-Encoding "gzip" and then "GZIP",
// and requires the client to hand back the decoded text in both cases.
func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	for _, ce := range []string{"gzip", "GZIP"} {
		ce := ce // capture per-iteration value for the closures below
		t.Run(ce, func(t *testing.T) {
			const encodedString = "Hello Gopher"
			ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
				w.Header().Set("Content-Encoding", ce)
				gz := gzip.NewWriter(w)
				gz.Write([]byte(encodedString))
				gz.Close()
			}))
			defer ts.Close()

			res, err := ts.Client().Get(ts.URL)
			if err != nil {
				t.Fatal(err)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Fatal(err)
			}

			// The body must already be decompressed: the transport is
			// expected to recognize the encoding regardless of case.
			if string(body) != encodedString {
				t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
			}
		})
	}
}

// TestTransportDialCancelRace cancels the request from the hook installed
// with SetEnterRoundTripHook (presumably invoked on entry to RoundTrip —
// defined outside this file section) and expects RoundTrip to fail with
// ExportErrRequestCanceled.
func TestTransportDialCancelRace(t *testing.T) {
	defer afterTest(t)

	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer ts.Close()
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	SetEnterRoundTripHook(func() {
		tr.CancelRequest(req)
	})
	defer SetEnterRoundTripHook(nil)
	res, err := tr.RoundTrip(req)
	if err != ExportErrRequestCanceled {
		t.Errorf("expected canceled request error; got %v", err)
		if err == nil {
			// Unexpected success: still release the response.
			res.Body.Close()
		}
	}
}

// logWritesConn is a net.Conn that logs each Write call to writes
// and then proxies to w.
// It proxies Read calls to a reader it receives from rch.
// Close is a no-op; any other net.Conn method hits the nil embedded Conn.
type logWritesConn struct {
	net.Conn // nil. crash on use.

	w io.Writer

	rch <-chan io.Reader
	r   io.Reader // nil until received by rch

	mu     sync.Mutex // guards writes
	writes []string
}

// Write records a copy of p, then forwards it to c.w.
func (c *logWritesConn) Write(p []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, string(p))
	return c.w.Write(p)
}

// Read blocks on rch for the first call to obtain the reader, then reads
// from it. c.r is not guarded by mu; NOTE(review): assumes a single reading
// goroutine.
func (c *logWritesConn) Read(p []byte) (n int, err error) {
	if c.r == nil {
		c.r = <-c.rch
	}
	return c.r.Read(p)
}

// Close does nothing and reports success.
func (c *logWritesConn) Close() error { return nil }

// Issue 6574
//
// TestTransportFlushesBodyChunks streams a POST body through an io.Pipe and
// checks, via logWritesConn, that the header and each chunk of the chunked
// body reach the connection as separate Write calls in the expected order.
func TestTransportFlushesBodyChunks(t *testing.T) {
	defer afterTest(t)
	resBody := make(chan io.Reader, 1)
	connr, connw := io.Pipe() // connection pipe pair
	lw := &logWritesConn{
		rch: resBody,
		w:   connw,
	}
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			return lw, nil
		},
	}
	bodyr, bodyw := io.Pipe() // body pipe pair
	go func() {
		defer bodyw.Close()
		for i := 0; i < 3; i++ {
			fmt.Fprintf(bodyw, "num%d\n", i)
		}
	}()
	resc := make(chan *Response)
	go func() {
		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
		req.Header.Set("User-Agent", "x") // known value for test
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("RoundTrip: %v", err)
			close(resc) // signals failure to the receiver below
			return
		}
		resc <- res

	}()
	// Fully consume the request before checking the Write log vs. want.
	req, err := ReadRequest(bufio.NewReader(connr))
	if err != nil {
		t.Fatal(err)
	}
	io.Copy(io.Discard, req.Body)

	// Unblock the transport's roundTrip goroutine.
	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
	res, ok := <-resc
	if !ok {
		// RoundTrip failed; the goroutine already reported it.
		return
	}
	defer res.Body.Close()

	want := []string{
		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
		"5\r\nnum0\n\r\n",
		"5\r\nnum1\n\r\n",
		"5\r\nnum2\n\r\n",
		"0\r\n\r\n",
	}
	if !reflect.DeepEqual(lw.writes, want) {
		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
	}
}

// Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
func TestTransportFlushesRequestHeader(t *testing.T) { defer afterTest(t) gotReq := make(chan struct{}) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(gotReq) })) defer cst.close() pr, pw := io.Pipe() req, err := NewRequest("POST", cst.ts.URL, pr) if err != nil { t.Fatal(err) } gotRes := make(chan struct{}) go func() { defer close(gotRes) res, err := cst.tr.RoundTrip(req) if err != nil { t.Error(err) return } res.Body.Close() }() select { case <-gotReq: pw.Close() case <-time.After(5 * time.Second): t.Fatal("timeout waiting for handler to get request") } <-gotRes } // Issue 11745. func TestTransportPrefersResponseOverWriteError(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } defer afterTest(t) const contentLengthLimit = 1024 * 1024 // 1MB ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.ContentLength >= contentLengthLimit { w.WriteHeader(StatusBadRequest) r.Body.Close() return } w.WriteHeader(StatusOK) })) defer ts.Close() c := ts.Client() fail := 0 count := 100 bigBody := strings.Repeat("a", contentLengthLimit*2) for i := 0; i < count; i++ { req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody)) if err != nil { t.Fatal(err) } resp, err := c.Do(req) if err != nil { fail++ t.Logf("%d = %#v", i, err) if ue, ok := err.(*url.Error); ok { t.Logf("urlErr = %#v", ue.Err) if ne, ok := ue.Err.(*net.OpError); ok { t.Logf("netOpError = %#v", ne.Err) } } } else { resp.Body.Close() if resp.StatusCode != 400 { t.Errorf("Expected status code 400, got %v", resp.Status) } } } if fail > 0 { t.Errorf("Failed %v out of %v\n", fail, count) } } func TestTransportAutomaticHTTP2(t *testing.T) { testTransportAutoHTTP(t, &Transport{}, true) } func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) { testTransportAutoHTTP(t, &Transport{ ForceAttemptHTTP2: true, TLSClientConfig: new(tls.Config), }, true) } // golang.org/issue/14391: also check 
DefaultTransport func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) { testTransportAutoHTTP(t, DefaultTransport.(*Transport), true) } func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { testTransportAutoHTTP(t, &Transport{ TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper), }, false) } func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) { testTransportAutoHTTP(t, &Transport{ TLSClientConfig: new(tls.Config), }, false) } func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) { testTransportAutoHTTP(t, &Transport{ ExpectContinueTimeout: 1 * time.Second, }, true) } func TestTransportAutomaticHTTP2_Dial(t *testing.T) { var d net.Dialer testTransportAutoHTTP(t, &Transport{ Dial: d.Dial, }, false) } func TestTransportAutomaticHTTP2_DialContext(t *testing.T) { var d net.Dialer testTransportAutoHTTP(t, &Transport{ DialContext: d.DialContext, }, false) } func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) { testTransportAutoHTTP(t, &Transport{ DialTLS: func(network, addr string) (net.Conn, error) { panic("unused") }, }, false) } func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { CondSkipHTTP2(t) _, err := tr.RoundTrip(new(Request)) if err == nil { t.Error("expected error from RoundTrip") } if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 { t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2) } } // Issue 13633: there was a race where we returned bodyless responses // to callers before recycling the persistent connection, which meant // a client doing two subsequent requests could end up on different // connections. It's somewhat harmless but enough tests assume it's // not true in order to test other things that it's worth fixing. // Plus it's nice to be consistent and not have timing-dependent // behavior. 
func TestTransportReuseConnEmptyResponseBody(t *testing.T) { defer afterTest(t) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) // Empty response body. })) defer cst.close() n := 100 if testing.Short() { n = 10 } var firstAddr string for i := 0; i < n; i++ { res, err := cst.c.Get(cst.ts.URL) if err != nil { log.Fatal(err) } addr := res.Header.Get("X-Addr") if i == 0 { firstAddr = addr } else if addr != firstAddr { t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) } res.Body.Close() } } // Issue 13839 func TestNoCrashReturningTransportAltConn(t *testing.T) { cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) if err != nil { t.Fatal(err) } ln := newLocalListener(t) defer ln.Close() var wg sync.WaitGroup SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) defer SetPendingDialHooks(nil, nil) testDone := make(chan struct{}) defer close(testDone) go func() { tln := tls.NewListener(ln, &tls.Config{ NextProtos: []string{"foo"}, Certificates: []tls.Certificate{cert}, }) sc, err := tln.Accept() if err != nil { t.Error(err) return } if err := sc.(*tls.Conn).Handshake(); err != nil { t.Error(err) return } <-testDone sc.Close() }() addr := ln.Addr().String() req, _ := NewRequest("GET", "https://fake.tld/", nil) cancel := make(chan struct{}) req.Cancel = cancel doReturned := make(chan bool, 1) madeRoundTripper := make(chan bool, 1) tr := &Transport{ DisableKeepAlives: true, TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ "foo": func(authority string, c *tls.Conn) RoundTripper { madeRoundTripper <- true return funcRoundTripper(func() { t.Error("foo RoundTripper should not be called") }) }, }, Dial: func(_, _ string) (net.Conn, error) { panic("shouldn't be called") }, DialTLS: func(_, _ string) (net.Conn, error) { tc, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"foo"}, }) if err != nil { return 
nil, err } if err := tc.Handshake(); err != nil { return nil, err } close(cancel) <-doReturned return tc, nil }, } c := &Client{Transport: tr} _, err = c.Do(req) if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn { t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err) } doReturned <- true <-madeRoundTripper wg.Wait() } func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { testTransportReuseConnection_Gzip(t, true) } func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { testTransportReuseConnection_Gzip(t, false) } // Make sure we re-use underlying TCP connection for gzipped responses too. func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { setParallel(t) defer afterTest(t) addr := make(chan string, 2) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { addr <- r.RemoteAddr w.Header().Set("Content-Encoding", "gzip") if chunked { w.(Flusher).Flush() } w.Write(rgz) // arbitrary gzip response })) defer ts.Close() c := ts.Client() for i := 0; i < 2; i++ { res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } buf := make([]byte, len(rgz)) if n, err := io.ReadFull(res.Body, buf); err != nil { t.Errorf("%d. ReadFull = %v, %v", i, n, err) } // Note: no res.Body.Close call. It should work without it, // since the flate.Reader's internal buffering will hit EOF // and that should be sufficient. 
} a1, a2 := <-addr, <-addr if a1 != a2 { t.Fatalf("didn't reuse connection") } } func TestTransportResponseHeaderLength(t *testing.T) { setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/long" { w.Header().Set("Long", strings.Repeat("a", 1<<20)) } })) defer ts.Close() c := ts.Client() c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 if res, err := c.Get(ts.URL); err != nil { t.Fatal(err) } else { res.Body.Close() } res, err := c.Get(ts.URL + "/long") if err == nil { defer res.Body.Close() var n int64 for k, vv := range res.Header { for _, v := range vv { n += int64(len(k)) + int64(len(v)) } } t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n) } if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { t.Errorf("got error: %v; want %q", err, want) } } func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, h1Mode, false) } func TestTransportEventTrace_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, false) } // test a non-nil httptrace.ClientTrace but with all hooks set to zero. func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, h1Mode, true) } func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, true) } func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { defer afterTest(t) const resBody = "some body" gotWroteReqEvent := make(chan struct{}, 500) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { // Do nothing for the second request. 
return } if _, err := io.ReadAll(r.Body); err != nil { t.Error(err) } if !noHooks { select { case <-gotWroteReqEvent: case <-time.After(5 * time.Second): t.Error("timeout waiting for WroteRequest event") } } io.WriteString(w, resBody) })) defer cst.close() cst.tr.ExpectContinueTimeout = 1 * time.Second var mu sync.Mutex // guards buf var buf bytes.Buffer logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) buf.WriteByte('\n') } addrStr := cst.ts.Listener.Addr().String() ip, port, err := net.SplitHostPort(addrStr) if err != nil { t.Fatal(err) } // Install a fake DNS server. ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { if host != "dns-is-faked.golang" { t.Errorf("unexpected DNS host lookup for %q/%q", network, host) return nil, nil } return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil }) body := "some body" req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body)) req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"} trace := &httptrace.ClientTrace{ GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, GotFirstResponseByte: func() { logf("first response byte") }, PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) }, DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) }, DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) }, ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) }, ConnectDone: func(network, addr string, err error) { if err != nil { t.Errorf("ConnectDone: %v", err) } logf("ConnectDone: connected to %s %s = %v", network, addr, err) }, WroteHeaderField: func(key string, value []string) { logf("WroteHeaderField: %s: %v", key, value) }, WroteHeaders: func() { 
logf("WroteHeaders") }, Wait100Continue: func() { logf("Wait100Continue") }, Got100Continue: func() { logf("Got100Continue") }, WroteRequest: func(e httptrace.WroteRequestInfo) { logf("WroteRequest: %+v", e) gotWroteReqEvent <- struct{}{} }, } if h2 { trace.TLSHandshakeStart = func() { logf("tls handshake start") } trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) } } if noHooks { // zero out all func pointers, trying to get some path to crash *trace = httptrace.ClientTrace{} } req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) req.Header.Set("Expect", "100-continue") res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } logf("got roundtrip.response") slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } logf("consumed body") if string(slurp) != resBody || res.StatusCode != 200 { t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody) } res.Body.Close() if noHooks { // Done at this point. Just testing a full HTTP // requests can happen with a trace pointing to a zero // ClientTrace, full of nil func pointers. 
return } mu.Lock() got := buf.String() mu.Unlock() wantOnce := func(sub string) { if strings.Count(got, sub) != 1 { t.Errorf("expected substring %q exactly once in output.", sub) } } wantOnceOrMore := func(sub string) { if strings.Count(got, sub) == 0 { t.Errorf("expected substring %q at least once in output.", sub) } } wantOnce("Getting conn for dns-is-faked.golang:" + port) wantOnce("DNS start: {Host:dns-is-faked.golang}") wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err: Coalesced:false}") wantOnce("got conn: {") wantOnceOrMore("Connecting to tcp " + addrStr) wantOnceOrMore("connected to tcp " + addrStr + " = ") wantOnce("Reused:false WasIdle:false IdleTime:0s") wantOnce("first response byte") if h2 { wantOnce("tls handshake start") wantOnce("tls handshake done") } else { wantOnce("PutIdleConn = ") wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]") // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the // WroteHeaderField hook is not yet implemented in h2.) 
wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port)) wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body))) wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]") wantOnce("WroteHeaderField: Accept-Encoding: [gzip]") } wantOnce("WroteHeaders") wantOnce("Wait100Continue") wantOnce("Got100Continue") wantOnce("WroteRequest: {Err:}") if strings.Contains(got, " to udp ") { t.Errorf("should not see UDP (DNS) connections") } if t.Failed() { t.Errorf("Output:\n%s", got) } // And do a second request: req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil) req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) res, err = cst.c.Do(req) if err != nil { t.Fatal(err) } if res.StatusCode != 200 { t.Fatal(res.Status) } res.Body.Close() mu.Lock() got = buf.String() mu.Unlock() sub := "Getting conn for dns-is-faked.golang:" if gotn, want := strings.Count(got, sub), 2; gotn != want { t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got) } } func TestTransportEventTraceTLSVerify(t *testing.T) { var mu sync.Mutex var buf bytes.Buffer logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) 
buf.WriteByte('\n') } ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("Unexpected request") })) defer ts.Close() ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { logf("%s", p) return len(p), nil }), "", 0) certpool := x509.NewCertPool() certpool.AddCert(ts.Certificate()) c := &Client{Transport: &Transport{ TLSClientConfig: &tls.Config{ ServerName: "dns-is-faked.golang", RootCAs: certpool, }, }} trace := &httptrace.ClientTrace{ TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, TLSHandshakeDone: func(s tls.ConnectionState, err error) { logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err) }, } req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) _, err := c.Do(req) if err == nil { t.Error("Expected request to fail TLS verification") } mu.Lock() got := buf.String() mu.Unlock() wantOnce := func(sub string) { if strings.Count(got, sub) != 1 { t.Errorf("expected substring %q exactly once in output.", sub) } } wantOnce("TLSHandshakeStart") wantOnce("TLSHandshakeDone") wantOnce("err = x509: certificate is valid for example.com") if t.Failed() { t.Errorf("Output:\n%s", got) } } var ( isDNSHijackedOnce sync.Once isDNSHijacked bool ) func skipIfDNSHijacked(t *testing.T) { // Skip this test if the user is using a shady/ISP // DNS server hijacking queries. // See issues 16732, 16716. isDNSHijackedOnce.Do(func() { addrs, _ := net.LookupHost("dns-should-not-resolve.golang") isDNSHijacked = len(addrs) != 0 }) if isDNSHijacked { t.Skip("skipping; test requires non-hijacking DNS server") } } func TestTransportEventTraceRealDNS(t *testing.T) { skipIfDNSHijacked(t) defer afterTest(t) tr := &Transport{} defer tr.CloseIdleConnections() c := &Client{Transport: tr} var mu sync.Mutex // guards buf var buf bytes.Buffer logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) 
buf.WriteByte('\n') } req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) trace := &httptrace.ClientTrace{ DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, } req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) resp, err := c.Do(req) if err == nil { resp.Body.Close() t.Fatal("expected error during DNS lookup") } mu.Lock() got := buf.String() mu.Unlock() wantSub := func(sub string) { if !strings.Contains(got, sub) { t.Errorf("expected substring %q in output.", sub) } } wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") wantSub("DNSDone: {Addrs:[] Err:") if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { t.Errorf("should not see Connect events") } if t.Failed() { t.Errorf("Output:\n%s", got) } } // Issue 14353: port can only contain digits. func TestTransportRejectsAlphaPort(t *testing.T) { res, err := Get("http://dummy.tld:123foo/bar") if err == nil { res.Body.Close() t.Fatal("unexpected success") } ue, ok := err.(*url.Error) if !ok { t.Fatalf("got %#v; want *url.Error", err) } got := ue.Err.Error() want := `invalid port ":123foo" after host` if got != want { t.Errorf("got error %q; want %q", got, want) } } // Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 // connections. 
The http2 test is done in TestTransportEventTrace_h2 func TestTLSHandshakeTrace(t *testing.T) { defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() var mu sync.Mutex var start, done bool trace := &httptrace.ClientTrace{ TLSHandshakeStart: func() { mu.Lock() defer mu.Unlock() start = true }, TLSHandshakeDone: func(s tls.ConnectionState, err error) { mu.Lock() defer mu.Unlock() done = true if err != nil { t.Fatal("Expected error to be nil but was:", err) } }, } c := ts.Client() req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal("Unable to construct test request:", err) } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) r, err := c.Do(req) if err != nil { t.Fatal("Unexpected error making request:", err) } r.Body.Close() mu.Lock() defer mu.Unlock() if !start { t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") } if !done { t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") } } func TestTransportMaxIdleConns(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. 
})) defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) tr.MaxIdleConns = 4 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) { return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil }) hitHost := func(n int) { req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil) req = req.WithContext(ctx) res, err := c.Do(req) if err != nil { t.Fatal(err) } res.Body.Close() } for i := 0; i < 4; i++ { hitHost(i) } want := []string{ "|http|host-0.dns-is-faked.golang:" + port, "|http|host-1.dns-is-faked.golang:" + port, "|http|host-2.dns-is-faked.golang:" + port, "|http|host-3.dns-is-faked.golang:" + port, } if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want) } // Now hitting the 5th host should kick out the first host: hitHost(4) want = []string{ "|http|host-1.dns-is-faked.golang:" + port, "|http|host-2.dns-is-faked.golang:" + port, "|http|host-3.dns-is-faked.golang:" + port, "|http|host-4.dns-is-faked.golang:" + port, } if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want) } } func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) } func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) } func testTransportIdleConnTimeout(t *testing.T, h2 bool) { if testing.Short() { t.Skip("skipping in short mode") } defer afterTest(t) const timeout = 1 * time.Second cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. 
})) defer cst.close() tr := cst.tr tr.IdleConnTimeout = timeout defer tr.CloseIdleConnections() c := &Client{Transport: tr} idleConns := func() []string { if h2 { return tr.IdleConnStrsForTesting_h2() } else { return tr.IdleConnStrsForTesting() } } var conn string doReq := func(n int) { req, _ := NewRequest("GET", cst.ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ PutIdleConn: func(err error) { if err != nil { t.Errorf("failed to keep idle conn: %v", err) } }, })) res, err := c.Do(req) if err != nil { t.Fatal(err) } res.Body.Close() conns := idleConns() if len(conns) != 1 { t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) } if conn == "" { conn = conns[0] } if conn != conns[0] { t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n) } } for i := 0; i < 3; i++ { doReq(i) time.Sleep(timeout / 2) } time.Sleep(timeout * 3 / 2) if got := idleConns(); len(got) != 0 { t.Errorf("idle conns = %q; want none", got) } } // Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an // HTTP/2 connection was established but its caller no longer // wanted it. (Assuming the connection cache was enabled, which it is // by default) // // This test reproduced the crash by setting the IdleConnTimeout low // (to make the test reasonable) and then making a request which is // canceled by the DialTLS hook, which then also waits to return the // real connection until after the RoundTrip saw the error. Then we // know the successful tls.Dial from DialTLS will need to go into the // idle pool. Then we give it a of time to explode. 
// TestIdleConnH2Crash sets cst.tr.IdleConnTimeout to 5ms and uses a DialTLS
// hook that cancels the request context after a successful h2 handshake.
// The hook then keeps the conn until Do has failed (or the test ends, or 5s
// pass) before returning it. The test then sleeps 10*IdleConnTimeout so
// that any crash triggered by that now-unwanted conn can surface.
func TestIdleConnH2Crash(t *testing.T) {
	setParallel(t)
	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// nothing
	}))
	defer cst.close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sawDoErr := make(chan bool, 1)       // signaled once cst.c.Do has returned its error
	testDone := make(chan struct{})      // closed when the test function returns
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the caller's request while we still hold a good conn.
		cancel()

		// Hold the conn until Do has observed the cancellation, so the
		// returned conn is no longer wanted by anyone.
		failTimer := time.NewTimer(5 * time.Second)
		defer failTimer.Stop()
		select {
		case <-sawDoErr:
		case <-testDone:
		case <-failTimer.C:
			t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail")
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Wait for the explosion.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}

// funcConn is a net.Conn whose Read and Write delegate to the read and
// write funcs, and whose Close is a no-op returning nil. All other net.Conn
// methods come from the embedded Conn; if that is left nil, calling them
// panics.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

func (c funcConn) Read(p []byte) (int, error)  { return c.read(p) }
func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
func (c funcConn) Close() error                { return nil }

// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
// back to the caller.
func TestTransportReturnsPeekError(t *testing.T) { errValue := errors.New("specific error value") wrote := make(chan struct{}) var wroteOnce sync.Once tr := &Transport{ Dial: func(network, addr string) (net.Conn, error) { c := funcConn{ read: func([]byte) (int, error) { <-wrote return 0, errValue }, write: func(p []byte) (int, error) { wroteOnce.Do(func() { close(wrote) }) return len(p), nil }, } return c, nil }, } _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) if err != errValue { t.Errorf("error = %#v; want %v", err, errValue) } } // Issue 13835: international domain names should work func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) } func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) } func testTransportIDNA(t *testing.T, h2 bool) { defer afterTest(t) const uniDomain = "гофер.го" const punyDomain = "xn--c1ae0ajs.xn--c1aw" var port string cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { want := punyDomain + ":" + port if r.Host != want { t.Errorf("Host header = %q; want %q", r.Host, want) } if h2 { if r.TLS == nil { t.Errorf("r.TLS == nil") } else if r.TLS.ServerName != punyDomain { t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain) } } w.Header().Set("Hit-Handler", "1") })) defer cst.close() ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } // Install a fake DNS server. 
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { if host != punyDomain { t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain) return nil, nil } return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil }) req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil) trace := &httptrace.ClientTrace{ GetConn: func(hostPort string) { want := net.JoinHostPort(punyDomain, port) if hostPort != want { t.Errorf("getting conn for %q; want %q", hostPort, want) } }, DNSStart: func(e httptrace.DNSStartInfo) { if e.Host != punyDomain { t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain) } }, } req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) res, err := cst.tr.RoundTrip(req) if err != nil { t.Fatal(err) } defer res.Body.Close() if res.Header.Get("Hit-Handler") != "1" { out, err := httputil.DumpResponse(res, true) if err != nil { t.Fatal(err) } t.Errorf("Response body wasn't from Handler. 
Got:\n%s\n", out) } } // Issue 13290: send User-Agent in proxy CONNECT func TestTransportProxyConnectHeader(t *testing.T) { defer afterTest(t) reqc := make(chan *Request, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } reqc <- r c, _, err := w.(Hijacker).Hijack() if err != nil { t.Errorf("Hijack: %v", err) return } c.Close() })) defer ts.Close() c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { return url.Parse(ts.URL) } c.Transport.(*Transport).ProxyConnectHeader = Header{ "User-Agent": {"foo"}, "Other": {"bar"}, } res, err := c.Get("https://dummy.tld/") // https to force a CONNECT if err == nil { res.Body.Close() t.Errorf("unexpected success") } select { case <-time.After(3 * time.Second): t.Fatal("timeout") case r := <-reqc: if got, want := r.Header.Get("User-Agent"), "foo"; got != want { t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) } if got, want := r.Header.Get("Other"), "bar"; got != want { t.Errorf("CONNECT request Other = %q; want %q", got, want) } } } func TestTransportProxyGetConnectHeader(t *testing.T) { defer afterTest(t) reqc := make(chan *Request, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } reqc <- r c, _, err := w.(Hijacker).Hijack() if err != nil { t.Errorf("Hijack: %v", err) return } c.Close() })) defer ts.Close() c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { return url.Parse(ts.URL) } // These should be ignored: c.Transport.(*Transport).ProxyConnectHeader = Header{ "User-Agent": {"foo"}, "Other": {"bar"}, } c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) { return Header{ "User-Agent": {"foo2"}, "Other": {"bar2"}, }, nil } res, err := c.Get("https://dummy.tld/") // 
https to force a CONNECT if err == nil { res.Body.Close() t.Errorf("unexpected success") } select { case <-time.After(3 * time.Second): t.Fatal("timeout") case r := <-reqc: if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) } if got, want := r.Header.Get("Other"), "bar2"; got != want { t.Errorf("CONNECT request Other = %q; want %q", got, want) } } } var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) { fn() return nil, errFakeRoundTrip } func wantBody(res *Response, err error, want string) error { if err != nil { return err } slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("error reading body: %v", err) } if string(slurp) != want { return fmt.Errorf("body = %q; want %q", slurp, want) } if err := res.Body.Close(); err != nil { return fmt.Errorf("body Close = %v", err) } return nil } func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { ln, err = net.Listen("tcp6", "[::1]:0") } if err != nil { t.Fatal(err) } return ln } type countCloseReader struct { n *int io.Reader } func (cr countCloseReader) Close() error { (*cr.n)++ return nil } // rgz is a gzip quine that uncompresses to itself. 
var rgz = []byte{ 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 0x00, 0x00, } // Ensure that a missing status doesn't make the server panic // See Issue https://golang.org/issues/21701 func TestMissingStatusNoPanic(t *testing.T) { t.Parallel() const want = "unknown status code" ln := newLocalListener(t) addr := ln.Addr().String() done := make(chan bool) fullAddrURL := fmt.Sprintf("http://%s", addr) raw := "HTTP/1.1 400\r\n" + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + "Content-Type: text/html; charset=utf-8\r\n" + "Content-Length: 10\r\n" + 
"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + "Vary: Accept-Encoding\r\n\r\n" + "Aloha Olaa" go func() { defer close(done) conn, _ := ln.Accept() if conn != nil { io.WriteString(conn, raw) io.ReadAll(conn) conn.Close() } }() proxyURL, err := url.Parse(fullAddrURL) if err != nil { t.Fatalf("proxyURL: %v", err) } tr := &Transport{Proxy: ProxyURL(proxyURL)} req, _ := NewRequest("GET", "https://golang.org/", nil) res, err, panicked := doFetchCheckPanic(tr, req) if panicked { t.Error("panicked, expecting an error") } if res != nil && res.Body != nil { io.Copy(io.Discard, res.Body) res.Body.Close() } if err == nil || !strings.Contains(err.Error(), want) { t.Errorf("got=%v want=%q", err, want) } ln.Close() <-done } func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) { defer func() { if r := recover(); r != nil { panicked = true } }() res, err = tr.RoundTrip(req) return } // Issue 22330: do not allow the response body to be read when the status code // forbids a response body. func TestNoBodyOnChunked304Response(t *testing.T) { defer afterTest(t) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) buf.Flush() conn.Close() })) defer cst.close() // Our test server above is sending back bogus data after the // response (the "0\r\n\r\n" part), which causes the Transport // code to log spam. Disable keep-alives so we never even try // to reuse the connection. 
cst.tr.DisableKeepAlives = true res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } if res.Body != NoBody { t.Errorf("Unexpected body on 304 response") } } type funcWriter func([]byte) (int, error) func (f funcWriter) Write(p []byte) (int, error) { return f(p) } type doneContext struct { context.Context err error } func (doneContext) Done() <-chan struct{} { c := make(chan struct{}) close(c) return c } func (d doneContext) Err() error { return d.err } // Issue 25852: Transport should check whether Context is done early. func TestTransportCheckContextDoneEarly(t *testing.T) { tr := &Transport{} req, _ := NewRequest("GET", "http://fake.example/", nil) wantErr := errors.New("some error") req = req.WithContext(doneContext{context.Background(), wantErr}) _, err := tr.RoundTrip(req) if err != wantErr { t.Errorf("error = %v; want %v", err, wantErr) } } // Issue 23399: verify that if a client request times out, the Transport's // conn is closed so that it's not reused. // // This is the test variant that times out before the server replies with // any response headers. func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { setParallel(t) defer afterTest(t) inHandler := make(chan net.Conn, 1) handlerReadReturned := make(chan bool, 1) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } inHandler <- conn n, err := conn.Read([]byte{0}) if n != 0 || err != io.EOF { t.Errorf("unexpected Read result: %v, %v", n, err) } handlerReadReturned <- true })) defer cst.close() const timeout = 50 * time.Millisecond cst.c.Timeout = timeout _, err := cst.c.Get(cst.ts.URL) if err == nil { t.Fatal("unexpected Get succeess") } select { case c := <-inHandler: select { case <-handlerReadReturned: // Success. 
return case <-time.After(5 * time.Second): t.Error("Handler's conn.Read seems to be stuck in Read") c.Close() // close it to unblock Handler } case <-time.After(timeout * 10): // If we didn't get into the Handler in 50ms, that probably means // the builder was just slow and the Get failed in that time // but never made it to the server. That's fine. We'll usually // test the part above on faster machines. t.Skip("skipping test on slow builder") } } // Issue 23399: verify that if a client request times out, the Transport's // conn is closed so that it's not reused. // // This is the test variant that has the server send response headers // first, and time out during the write of the response body. func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { setParallel(t) defer afterTest(t) inHandler := make(chan net.Conn, 1) handlerResult := make(chan error, 1) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "100") w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } conn.Write([]byte("foo")) inHandler <- conn n, err := conn.Read([]byte{0}) // The error should be io.EOF or "read tcp // 127.0.0.1:35827->127.0.0.1:40290: read: connection // reset by peer" depending on timing. Really we just // care that it returns at all. But if it returns with // data, that's weird. if n != 0 || err == nil { handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err) return } handlerResult <- nil })) defer cst.close() // Set Timeout to something very long but non-zero to exercise // the codepaths that check for it. But rather than wait for it to fire // (which would make the test slow), we send on the req.Cancel channel instead, // which happens to exercise the same code paths. cst.c.Timeout = time.Minute // just to be non-zero, not to hit it. 
req, _ := NewRequest("GET", cst.ts.URL, nil) cancel := make(chan struct{}) req.Cancel = cancel res, err := cst.c.Do(req) if err != nil { select { case <-inHandler: t.Fatalf("Get error: %v", err) default: // Failed before entering handler. Ignore result. t.Skip("skipping test on slow builder") } } close(cancel) got, err := io.ReadAll(res.Body) if err == nil { t.Fatalf("unexpected success; read %q, nil", got) } select { case c := <-inHandler: select { case err := <-handlerResult: if err != nil { t.Errorf("handler: %v", err) } return case <-time.After(5 * time.Second): t.Error("Handler's conn.Read seems to be stuck in Read") c.Close() // close it to unblock Handler } case <-time.After(5 * time.Second): t.Fatal("timeout") } } func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { setParallel(t) defer afterTest(t) done := make(chan struct{}) defer close(done) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } defer conn.Close() io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") bs := bufio.NewScanner(conn) bs.Scan() fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) <-done })) defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) req.Header.Set("Upgrade", "foo") req.Header.Set("Connection", "upgrade") res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } if res.StatusCode != 101 { t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) } rwc, ok := res.Body.(io.ReadWriteCloser) if !ok { t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) } defer rwc.Close() bs := bufio.NewScanner(rwc) if !bs.Scan() { t.Fatalf("expected readable input") } if got, want := bs.Text(), "Some buffered data"; got != want { t.Errorf("read %q; want %q", got, want) } io.WriteString(rwc, "echo\n") if !bs.Scan() { t.Fatalf("expected another line") } if 
got, want := bs.Text(), "ECHO"; got != want { t.Errorf("read %q; want %q", got, want) } } func TestTransportCONNECTBidi(t *testing.T) { defer afterTest(t) const target = "backend:443" cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("unexpected method %q", r.Method) w.WriteHeader(500) return } if r.RequestURI != target { t.Errorf("unexpected CONNECT target %q", r.RequestURI) w.WriteHeader(500) return } nc, brw, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } defer nc.Close() nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) // Switch to a little protocol that capitalize its input lines: for { line, err := brw.ReadString('\n') if err != nil { if err != io.EOF { t.Error(err) } return } io.WriteString(brw, strings.ToUpper(line)) brw.Flush() } })) defer cst.close() pr, pw := io.Pipe() defer pw.Close() req, err := NewRequest("CONNECT", cst.ts.URL, pr) if err != nil { t.Fatal(err) } req.URL.Opaque = target res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() if res.StatusCode != 200 { t.Fatalf("status code = %d; want 200", res.StatusCode) } br := bufio.NewReader(res.Body) for _, str := range []string{"foo", "bar", "baz"} { fmt.Fprintf(pw, "%s\n", str) got, err := br.ReadString('\n') if err != nil { t.Fatal(err) } got = strings.TrimSpace(got) want := strings.ToUpper(str) if got != want { t.Fatalf("got %q; want %q", got, want) } } } func TestTransportRequestReplayable(t *testing.T) { someBody := io.NopCloser(strings.NewReader("")) tests := []struct { name string req *Request want bool }{ { name: "GET", req: &Request{Method: "GET"}, want: true, }, { name: "GET_http.NoBody", req: &Request{Method: "GET", Body: NoBody}, want: true, }, { name: "GET_body", req: &Request{Method: "GET", Body: someBody}, want: false, }, { name: "POST", req: &Request{Method: "POST"}, want: false, }, { name: "POST_idempotency-key", req: &Request{Method: "POST", Header: 
Header{"Idempotency-Key": {"x"}}}, want: true, }, { name: "POST_x-idempotency-key", req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}}, want: true, }, { name: "POST_body", req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody}, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.req.ExportIsReplayable() if got != tt.want { t.Errorf("replyable = %v; want %v", got, tt.want) } }) } } // testMockTCPConn is a mock TCP connection used to test that // ReadFrom is called when sending the request body. type testMockTCPConn struct { *net.TCPConn ReadFromCalled bool } func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { c.ReadFromCalled = true return c.TCPConn.ReadFrom(r) } func TestTransportRequestWriteRoundTrip(t *testing.T) { nBytes := int64(1 << 10) newFileFunc := func() (r io.Reader, done func(), err error) { f, err := os.CreateTemp("", "net-http-newfilefunc") if err != nil { return nil, nil, err } // Write some bytes to the file to enable reading. 
if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { return nil, nil, fmt.Errorf("failed to write data to file: %v", err) } if _, err := f.Seek(0, 0); err != nil { return nil, nil, fmt.Errorf("failed to seek to front: %v", err) } done = func() { f.Close() os.Remove(f.Name()) } return f, done, nil } newBufferFunc := func() (io.Reader, func(), error) { return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil } cases := []struct { name string readerFunc func() (io.Reader, func(), error) contentLength int64 expectedReadFrom bool }{ { name: "file, length", readerFunc: newFileFunc, contentLength: nBytes, expectedReadFrom: true, }, { name: "file, no length", readerFunc: newFileFunc, }, { name: "file, negative length", readerFunc: newFileFunc, contentLength: -1, }, { name: "buffer", contentLength: nBytes, readerFunc: newBufferFunc, }, { name: "buffer, no length", readerFunc: newBufferFunc, }, { name: "buffer, length -1", contentLength: -1, readerFunc: newBufferFunc, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r, cleanup, err := tc.readerFunc() if err != nil { t.Fatal(err) } defer cleanup() tConn := &testMockTCPConn{} trFunc := func(tr *Transport) { tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { var d net.Dialer conn, err := d.DialContext(ctx, network, addr) if err != nil { return nil, err } tcpConn, ok := conn.(*net.TCPConn) if !ok { return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) } tConn.TCPConn = tcpConn return tConn, nil } } cst := newClientServerTest( t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) r.Body.Close() w.WriteHeader(200) }), trFunc, ) defer cst.close() req, err := NewRequest("PUT", cst.ts.URL, r) if err != nil { t.Fatal(err) } req.ContentLength = tc.contentLength req.Header.Set("Content-Type", "application/octet-stream") resp, err := cst.c.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if 
resp.StatusCode != 200 { t.Fatalf("status code = %d; want 200", resp.StatusCode) } if !tConn.ReadFromCalled && tc.expectedReadFrom { t.Fatalf("did not call ReadFrom") } if tConn.ReadFromCalled && !tc.expectedReadFrom { t.Fatalf("ReadFrom was unexpectedly invoked") } }) } } func TestTransportClone(t *testing.T) { tr := &Transport{ Proxy: func(*Request) (*url.URL, error) { panic("") }, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, Dial: func(network, addr string) (net.Conn, error) { panic("") }, DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, TLSClientConfig: new(tls.Config), TLSHandshakeTimeout: time.Second, DisableKeepAlives: true, DisableCompression: true, MaxIdleConns: 1, MaxIdleConnsPerHost: 1, MaxConnsPerHost: 1, IdleConnTimeout: time.Second, ResponseHeaderTimeout: time.Second, ExpectContinueTimeout: time.Second, ProxyConnectHeader: Header{}, GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil }, MaxResponseHeaderBytes: 1, ForceAttemptHTTP2: true, TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") }, }, ReadBufferSize: 1, WriteBufferSize: 1, } tr2 := tr.Clone() rv := reflect.ValueOf(tr2).Elem() rt := rv.Type() for i := 0; i < rt.NumField(); i++ { sf := rt.Field(i) if !token.IsExported(sf.Name) { continue } if rv.Field(i).IsZero() { t.Errorf("cloned field t2.%s is zero", sf.Name) } } if _, ok := tr2.TLSNextProto["foo"]; !ok { t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") } // But test that a nil TLSNextProto is kept nil: tr = new(Transport) tr2 = tr.Clone() if tr2.TLSNextProto != nil { t.Errorf("Transport.TLSNextProto unexpected non-nil") } } func TestIs408(t *testing.T) { tests := []struct { in string want bool }{ {"HTTP/1.0 408", true}, {"HTTP/1.1 
408", true}, {"HTTP/1.8 408", true}, {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. {"HTTP/1.1 408 ", true}, {"HTTP/1.1 40", false}, {"http/1.0 408", false}, {"HTTP/1-1 408", false}, } for _, tt := range tests { if got := Export_is408Message([]byte(tt.in)); got != tt.want { t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) } } } func TestTransportIgnores408(t *testing.T) { // Not parallel. Relies on mutating the log package's global Output. defer log.SetOutput(log.Writer()) var logout bytes.Buffer log.SetOutput(&logout) defer afterTest(t) const target = "backend:443" cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { nc, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } defer nc.Close() nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail })) defer cst.close() req, err := NewRequest("GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) } res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if err != nil { t.Fatal(err) } if string(slurp) != "ok" { t.Fatalf("got %q; want ok", slurp) } t0 := time.Now() for i := 0; i < 50; i++ { time.Sleep(time.Duration(i) * 5 * time.Millisecond) if cst.tr.IdleConnKeyCountForTesting() == 0 { if got := logout.String(); got != "" { t.Fatalf("expected no log output; got: %s", got) } return } } t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) } func TestInvalidHeaderResponse(t *testing.T) { setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 200 OK\r\n" + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + "Content-Type: text/html; charset=utf-8\r\n" + "Content-Length: 0\r\n" + "Foo : bar\r\n\r\n")) buf.Flush() 
conn.Close() })) defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() if v := res.Header.Get("Foo"); v != "" { t.Errorf(`unexpected "Foo" header: %q`, v) } if v := res.Header.Get("Foo "); v != "bar" { t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar") } } type bodyCloser bool func (bc *bodyCloser) Close() error { *bc = true return nil } func (bc *bodyCloser) Read(b []byte) (n int, err error) { return 0, io.EOF } // Issue 35015: ensure that Transport closes the body on any error // with an invalid request, as promised by Client.Do docs. func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { t.Errorf("Should not have been invoked") })) defer cst.Close() u, _ := url.Parse(cst.URL) tests := []struct { name string req *Request wantErr string }{ { name: "invalid method", req: &Request{ Method: " ", URL: u, }, wantErr: "invalid method", }, { name: "nil URL", req: &Request{ Method: "GET", }, wantErr: "nil Request.URL", }, { name: "invalid header key", req: &Request{ Method: "GET", Header: Header{"💡": {"emoji"}}, URL: u, }, wantErr: "invalid header field name", }, { name: "invalid header value", req: &Request{ Method: "POST", Header: Header{"key": {"\x19"}}, URL: u, }, wantErr: "invalid header field value", }, { name: "non HTTP(s) scheme", req: &Request{ Method: "POST", URL: &url.URL{Scheme: "faux"}, }, wantErr: "unsupported protocol scheme", }, { name: "no Host in URL", req: &Request{ Method: "POST", URL: &url.URL{Scheme: "http"}, }, wantErr: "no Host", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var bc bodyCloser req := tt.req req.Body = &bc _, err := DefaultClient.Do(tt.req) if err == nil { t.Fatal("Expected an error") } if !bc { t.Fatal("Expected body to have been closed") } if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) { t.Fatalf("Error mismatch\n\t%q\ndoes not contain\n\t%q", g, w) } }) } } 
// breakableConn is a net.Conn wrapper with a Write method // that will fail when its brokenState is true. type breakableConn struct { net.Conn *brokenState } type brokenState struct { sync.Mutex broken bool } func (w *breakableConn) Write(b []byte) (n int, err error) { w.Lock() defer w.Unlock() if w.broken { return 0, errors.New("some write error") } return w.Conn.Write(b) } // Issue 34978: don't cache a broken HTTP/2 connection func TestDontCacheBrokenHTTP2Conn(t *testing.T) { cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) defer cst.close() var brokenState brokenState const numReqs = 5 var numDials, gotConns uint32 // atomic cst.tr.Dial = func(netw, addr string) (net.Conn, error) { atomic.AddUint32(&numDials, 1) c, err := net.Dial(netw, addr) if err != nil { t.Errorf("unexpected Dial error: %v", err) return nil, err } return &breakableConn{c, &brokenState}, err } for i := 1; i <= numReqs; i++ { brokenState.Lock() brokenState.broken = false brokenState.Unlock() // doBreak controls whether we break the TCP connection after the TLS // handshake (before the HTTP/2 handshake). We test a few failures // in a row followed by a final success. 
doBreak := i != numReqs ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime) atomic.AddUint32(&gotConns, 1) }, TLSHandshakeDone: func(cfg tls.ConnectionState, err error) { brokenState.Lock() defer brokenState.Unlock() if doBreak { brokenState.broken = true } }, }) req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) } _, err = cst.c.Do(req) if doBreak != (err != nil) { t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err) } } if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want { t.Errorf("GotConn calls = %v; want %v", got, want) } if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want { t.Errorf("Dials = %v; want %v", got, want) } } // Issue 34941 // When the client has too many concurrent requests on a single connection, // http.http2noCachedConnError is reported on multiple requests. There should // only be one decrement regardless of the number of failures. 
func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { defer afterTest(t) CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } }) ts := httptest.NewUnstartedServer(h) ts.EnableHTTP2 = true ts.StartTLS() defer ts.Close() c := ts.Client() tr := c.Transport.(*Transport) tr.MaxConnsPerHost = 1 if err := ExportHttp2ConfigureTransport(tr); err != nil { t.Fatalf("ExportHttp2ConfigureTransport: %v", err) } errCh := make(chan error, 300) doReq := func() { resp, err := c.Get(ts.URL) if err != nil { errCh <- fmt.Errorf("request failed: %v", err) return } defer resp.Body.Close() _, err = io.ReadAll(resp.Body) if err != nil { errCh <- fmt.Errorf("read body failed: %v", err) } } var wg sync.WaitGroup for i := 0; i < 300; i++ { wg.Add(1) go func() { defer wg.Done() doReq() }() } wg.Wait() close(errCh) for err := range errCh { t.Errorf("error occurred: %v", err) } } // Issue 36820 // Test that we use the older backward compatible cancellation protocol // when a RoundTripper is registered via RegisterProtocol. 
func TestAltProtoCancellation(t *testing.T) { defer afterTest(t) tr := &Transport{} c := &Client{ Transport: tr, Timeout: time.Millisecond, } tr.RegisterProtocol("timeout", timeoutProto{}) _, err := c.Get("timeout://bar.com/path") if err == nil { t.Error("request unexpectedly succeeded") } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) { t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr) } } var timeoutProtoErr = errors.New("canceled as expected") type timeoutProto struct{} func (timeoutProto) RoundTrip(req *Request) (*Response, error) { select { case <-req.Cancel: return nil, timeoutProtoErr case <-time.After(5 * time.Second): return nil, errors.New("request was not canceled") } } type roundTripFunc func(r *Request) (*Response, error) func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } // Issue 32441: body is not reset after ErrSkipAltProtocol func TestIssue32441(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero") } })) defer ts.Close() c := ts.Client() c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { // Draining body to trigger failure condition on actual request to server. if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero during round trip") } return nil, ErrSkipAltProtocol })) if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { t.Error(err) } } // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. 
func TestTransportRejectsSignInContentLength(t *testing.T) { cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "+3") w.Write([]byte("abc")) })) defer cst.Close() c := cst.Client() res, err := c.Get(cst.URL) if err == nil || res != nil { t.Fatal("Expected a non-nil error and a nil http.Response") } if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) { t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) } } // dumpConn is a net.Conn which writes to Writer and reads from Reader type dumpConn struct { io.Writer io.Reader } func (c *dumpConn) Close() error { return nil } func (c *dumpConn) LocalAddr() net.Addr { return nil } func (c *dumpConn) RemoteAddr() net.Addr { return nil } func (c *dumpConn) SetDeadline(t time.Time) error { return nil } func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } // delegateReader is a reader that delegates to another reader, // once it arrives on a channel. type delegateReader struct { c chan io.Reader r io.Reader // nil until received from c } func (r *delegateReader) Read(p []byte) (int, error) { if r.r == nil { var ok bool if r.r, ok = <-r.c; !ok { return 0, errors.New("delegate closed") } } return r.r.Read(p) } func testTransportRace(req *Request) { save := req.Body pr, pw := io.Pipe() defer pr.Close() defer pw.Close() dr := &delegateReader{c: make(chan io.Reader)} t := &Transport{ Dial: func(net, addr string) (net.Conn, error) { return &dumpConn{pw, dr}, nil }, } defer t.CloseIdleConnections() quitReadCh := make(chan struct{}) // Wait for the request before replying with a dummy response: go func() { defer close(quitReadCh) req, err := ReadRequest(bufio.NewReader(pr)) if err == nil { // Ensure all the body is read; otherwise // we'll get a partial dump. 
io.Copy(io.Discard, req.Body) req.Body.Close() } select { case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): case quitReadCh <- struct{}{}: // Ensure delegate is closed so Read doesn't block forever. close(dr.c) } }() t.RoundTrip(req) // Ensure the reader returns before we reset req.Body to prevent // a data race on req.Body. pw.Close() <-quitReadCh req.Body = save } // Issue 37669 // Test that a cancellation doesn't result in a data race due to the writeLoop // goroutine being left running, if the caller mutates the processed Request // upon completion. func TestErrorWriteLoopRace(t *testing.T) { if testing.Short() { return } t.Parallel() for i := 0; i < 1000; i++ { delay := time.Duration(mrand.Intn(5)) * time.Millisecond ctx, cancel := context.WithTimeout(context.Background(), delay) defer cancel() r := bytes.NewBuffer(make([]byte, 10000)) req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r) if err != nil { t.Fatal(err) } testTransportRace(req) } } // Issue 41600 // Test that a new request which uses the connection of an active request // cannot cause it to be canceled as well. func TestCancelRequestWhenSharingConnection(t *testing.T) { reqc := make(chan chan struct{}, 2) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) { ch := make(chan struct{}, 1) reqc <- ch <-ch w.Header().Add("Content-Length", "0") })) defer ts.Close() client := ts.Client() transport := client.Transport.(*Transport) transport.MaxIdleConns = 1 transport.MaxConnsPerHost = 1 var wg sync.WaitGroup wg.Add(1) putidlec := make(chan chan struct{}) go func() { defer wg.Done() ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ PutIdleConn: func(error) { // Signal that the idle conn has been returned to the pool, // and wait for the order to proceed. 
ch := make(chan struct{}) putidlec <- ch <-ch }, }) req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) res, err := client.Do(req) if err == nil { res.Body.Close() } if err != nil { t.Errorf("request 1: got err %v, want nil", err) } }() // Wait for the first request to receive a response and return the // connection to the idle pool. r1c := <-reqc close(r1c) idlec := <-putidlec wg.Add(1) cancelctx, cancel := context.WithCancel(context.Background()) go func() { defer wg.Done() req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil) res, err := client.Do(req) if err == nil { res.Body.Close() } if !errors.Is(err, context.Canceled) { t.Errorf("request 2: got err %v, want Canceled", err) } }() // Wait for the second request to arrive at the server, and then cancel // the request context. r2c := <-reqc cancel() // Give the cancelation a moment to take effect, and then unblock the first request. time.Sleep(1 * time.Millisecond) close(idlec) close(r2c) wg.Wait() } func TestHandlerAbortRacesBodyRead(t *testing.T) { setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { go io.Copy(io.Discard, req.Body) panic(ErrAbortHandler) })) defer ts.Close() var wg sync.WaitGroup for i := 0; i < 2; i++ { wg.Add(1) go func() { defer wg.Done() for j := 0; j < 10; j++ { const reqLen = 6 * 1024 * 1024 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) req.ContentLength = reqLen resp, _ := ts.Client().Transport.RoundTrip(req) if resp != nil { resp.Body.Close() } } }() } wg.Wait() }