Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "encoding/json"
17 "errors"
18 "fmt"
19 "internal/testenv"
20 "io"
21 "log"
22 "math/rand"
23 "net"
24 . "net/http"
25 "net/http/httptest"
26 "net/http/httptrace"
27 "net/http/httputil"
28 "net/http/internal"
29 "net/http/internal/testcert"
30 "net/url"
31 "os"
32 "os/exec"
33 "path/filepath"
34 "reflect"
35 "regexp"
36 "runtime"
37 "runtime/debug"
38 "strconv"
39 "strings"
40 "sync"
41 "sync/atomic"
42 "syscall"
43 "testing"
44 "time"
45 )
46
// dummyAddr is a fixed net.Addr: both Network and String report the
// underlying string unchanged.
type dummyAddr string

// oneConnListener is a net.Listener that yields exactly one pre-built
// connection; every Accept after the first reports io.EOF.
type oneConnListener struct {
	conn net.Conn
}

// Accept hands out the stored connection and forgets it, so any later call
// returns (nil, io.EOF).
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close is a no-op; it does not close a connection that was never accepted.
func (l *oneConnListener) Close() error { return nil }

// Addr reports the placeholder address "test-address".
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// Network implements net.Addr by returning the address string itself.
func (a dummyAddr) Network() string { return string(a) }

// String implements net.Addr by returning the address string itself.
func (a dummyAddr) String() string { return string(a) }

// noopConn supplies the address and deadline methods of net.Conn: fixed
// placeholder addresses, and deadline setters that ignore their argument
// and never fail. Embed it in a type that provides Read, Write and Close.
type noopConn struct{}

func (noopConn) LocalAddr() net.Addr              { return dummyAddr("local-addr") }
func (noopConn) RemoteAddr() net.Addr             { return dummyAddr("remote-addr") }
func (noopConn) SetDeadline(time.Time) error      { return nil }
func (noopConn) SetReadDeadline(time.Time) error  { return nil }
func (noopConn) SetWriteDeadline(time.Time) error { return nil }

// rwTestConn is a net.Conn assembled from an arbitrary Reader and Writer.
type rwTestConn struct {
	io.Reader
	io.Writer
	noopConn

	closeFunc func() error // if set, Close returns its result instead of signalling closec
	closec    chan bool    // receives true on Close when there is room; never blocks
}

// Close calls closeFunc when it is set. Otherwise it makes a non-blocking
// send of true on closec (skipped if closec is nil or full) and returns nil.
func (c *rwTestConn) Close() error {
	if fn := c.closeFunc; fn != nil {
		return fn()
	}
	select {
	case c.closec <- true:
	default:
	}
	return nil
}

// testConn is an in-memory net.Conn: Read drains readBuf and Write appends
// to writeBuf.
type testConn struct {
	readMu   sync.Mutex // held by Read while it drains readBuf
	readBuf  bytes.Buffer
	writeBuf bytes.Buffer
	closec   chan bool // receives true on Close when there is room; never blocks
	noopConn
}

// Read reads from readBuf while holding readMu.
func (c *testConn) Read(b []byte) (int, error) {
	c.readMu.Lock()
	defer c.readMu.Unlock()
	return c.readBuf.Read(b)
}

// Write appends b to writeBuf.
func (c *testConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }

// Close makes a non-blocking send of true on closec and always returns nil.
func (c *testConn) Close() error {
	select {
	case c.closec <- true:
	default:
	}
	return nil
}
132
133
134
// reqBytes converts a human-readable request (header lines separated by
// "\n") to wire format: surrounding whitespace is trimmed, each "\n"
// becomes "\r\n", and the blank line ending the header ("\r\n\r\n") is
// appended.
func reqBytes(req string) []byte {
	wire := strings.TrimSpace(req)
	wire = strings.ReplaceAll(wire, "\n", "\r\n")
	return []byte(wire + "\r\n\r\n")
}
138
139 type handlerTest struct {
140 logbuf bytes.Buffer
141 handler Handler
142 }
143
144 func newHandlerTest(h Handler) handlerTest {
145 return handlerTest{handler: h}
146 }
147
148 func (ht *handlerTest) rawResponse(req string) string {
149 reqb := reqBytes(req)
150 var output bytes.Buffer
151 conn := &rwTestConn{
152 Reader: bytes.NewReader(reqb),
153 Writer: &output,
154 closec: make(chan bool, 1),
155 }
156 ln := &oneConnListener{conn: conn}
157 srv := &Server{
158 ErrorLog: log.New(&ht.logbuf, "", 0),
159 Handler: ht.handler,
160 }
161 go srv.Serve(ln)
162 <-conn.closec
163 return output.String()
164 }
165
166 func TestConsumingBodyOnNextConn(t *testing.T) {
167 t.Parallel()
168 defer afterTest(t)
169 conn := new(testConn)
170 for i := 0; i < 2; i++ {
171 conn.readBuf.Write([]byte(
172 "POST / HTTP/1.1\r\n" +
173 "Host: test\r\n" +
174 "Content-Length: 11\r\n" +
175 "\r\n" +
176 "foo=1&bar=1"))
177 }
178
179 reqNum := 0
180 ch := make(chan *Request)
181 servech := make(chan error)
182 listener := &oneConnListener{conn}
183 handler := func(res ResponseWriter, req *Request) {
184 reqNum++
185 ch <- req
186 }
187
188 go func() {
189 servech <- Serve(listener, HandlerFunc(handler))
190 }()
191
192 var req *Request
193 req = <-ch
194 if req == nil {
195 t.Fatal("Got nil first request.")
196 }
197 if req.Method != "POST" {
198 t.Errorf("For request #1's method, got %q; expected %q",
199 req.Method, "POST")
200 }
201
202 req = <-ch
203 if req == nil {
204 t.Fatal("Got nil first request.")
205 }
206 if req.Method != "POST" {
207 t.Errorf("For request #2's method, got %q; expected %q",
208 req.Method, "POST")
209 }
210
211 if serveerr := <-servech; serveerr != io.EOF {
212 t.Errorf("Serve returned %q; expected EOF", serveerr)
213 }
214 }
215
// stringHandler is a Handler that reports its own string value in the
// "Result" response header. It writes no body.
type stringHandler string

// ServeHTTP sets the "Result" header to s.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
221
// handlers lists the ServeMux patterns registered by TestHostHandlers. Each
// pattern is paired with the message its stringHandler reports.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}
231
// vtests lists the requests TestHostHandlers sends. For each one it gives
// the expected value: the "Result" header of a 200 response, or the
// "Location" header of a 301 redirect.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Subtree roots requested without their trailing slash; the expected
	// value is the redirect Location.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
247
248 func TestHostHandlers(t *testing.T) {
249 setParallel(t)
250 defer afterTest(t)
251 mux := NewServeMux()
252 for _, h := range handlers {
253 mux.Handle(h.pattern, stringHandler(h.msg))
254 }
255 ts := httptest.NewServer(mux)
256 defer ts.Close()
257
258 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
259 if err != nil {
260 t.Fatal(err)
261 }
262 defer conn.Close()
263 cc := httputil.NewClientConn(conn, nil)
264 for _, vt := range vtests {
265 var r *Response
266 var req Request
267 if req.URL, err = url.Parse(vt.url); err != nil {
268 t.Errorf("cannot parse url: %v", err)
269 continue
270 }
271 if err := cc.Write(&req); err != nil {
272 t.Errorf("writing request: %v", err)
273 continue
274 }
275 r, err := cc.Read(&req)
276 if err != nil {
277 t.Errorf("reading response: %v", err)
278 continue
279 }
280 switch r.StatusCode {
281 case StatusOK:
282 s := r.Header.Get("Result")
283 if s != vt.expected {
284 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
285 }
286 case StatusMovedPermanently:
287 s := r.Header.Get("Location")
288 if s != vt.expected {
289 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
290 }
291 default:
292 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
293 }
294 }
295 }
296
297 var serveMuxRegister = []struct {
298 pattern string
299 h Handler
300 }{
301 {"/dir/", serve(200)},
302 {"/search", serve(201)},
303 {"codesearch.google.com/search", serve(202)},
304 {"codesearch.google.com/", serve(203)},
305 {"example.com/", HandlerFunc(checkQueryStringHandler)},
306 }
307
308
// serve returns a handler that replies with status code and an empty body.
func serve(code int) HandlerFunc {
	return HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(code)
	})
}
314
315
316
317
// checkQueryStringHandler rebuilds the request URL with scheme "http", the
// request's Host and no query. It replies 200 if that equals "http://"
// followed by the raw query string, and 500 otherwise. Callers put the
// expected host+path in the query (see serveMuxTests2), so a 200 shows that
// the request arrived at the URL named by its own query.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	want := "http://" + r.URL.RawQuery
	rebuilt := *r.URL
	rebuilt.Scheme, rebuilt.Host, rebuilt.RawQuery = "http", r.Host, ""

	status := 500
	if rebuilt.String() == want {
		status = 200
	}
	w.WriteHeader(status)
}
329
// serveMuxTests drives TestServeMuxHandler against the patterns in
// serveMuxRegister. Each row is a request (method, Host, URL path) together
// with the status the resolved handler must produce and the pattern that
// mux.Handler must report ("" when it reports none).
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},

	// Paths containing "." or ".." segments are cleaned via a redirect.
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// CONNECT requests still get the "/dir" -> "/dir/" redirect, but their
	// paths are matched as-is, without cleaning.
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
365
366 func TestServeMuxHandler(t *testing.T) {
367 setParallel(t)
368 mux := NewServeMux()
369 for _, e := range serveMuxRegister {
370 mux.Handle(e.pattern, e.h)
371 }
372
373 for _, tt := range serveMuxTests {
374 r := &Request{
375 Method: tt.method,
376 Host: tt.host,
377 URL: &url.URL{
378 Path: tt.path,
379 },
380 }
381 h, pattern := mux.Handler(r)
382 rr := httptest.NewRecorder()
383 h.ServeHTTP(rr, r)
384 if pattern != tt.pattern || rr.Code != tt.code {
385 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
386 }
387 }
388 }
389
390
391 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
392 setParallel(t)
393 defer func() {
394 if err := recover(); err == nil {
395 t.Error("expected call to mux.HandleFunc to panic")
396 }
397 }()
398 mux := NewServeMux()
399 mux.HandleFunc("/", nil)
400 }
401
// serveMuxTests2 feeds TestServeMuxHandlerRedirects. code is the final
// status expected. redirOk says whether a 301 redirect may be followed on
// the way to it.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
413
414
415
416 func TestServeMuxHandlerRedirects(t *testing.T) {
417 setParallel(t)
418 mux := NewServeMux()
419 for _, e := range serveMuxRegister {
420 mux.Handle(e.pattern, e.h)
421 }
422
423 for _, tt := range serveMuxTests2 {
424 tries := 1
425 turl := tt.url
426 for {
427 u, e := url.Parse(turl)
428 if e != nil {
429 t.Fatal(e)
430 }
431 r := &Request{
432 Method: tt.method,
433 Host: tt.host,
434 URL: u,
435 }
436 h, _ := mux.Handler(r)
437 rr := httptest.NewRecorder()
438 h.ServeHTTP(rr, r)
439 if rr.Code != 301 {
440 if rr.Code != tt.code {
441 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
442 }
443 break
444 }
445 if !tt.redirOk {
446 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
447 break
448 }
449 turl = rr.HeaderMap.Get("Location")
450 tries--
451 }
452 if tries < 0 {
453 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
454 }
455 }
456 }
457
458
459 func TestMuxRedirectLeadingSlashes(t *testing.T) {
460 setParallel(t)
461 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
462 for _, path := range paths {
463 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
464 if err != nil {
465 t.Errorf("%s", err)
466 }
467 mux := NewServeMux()
468 resp := httptest.NewRecorder()
469
470 mux.ServeHTTP(resp, req)
471
472 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
473 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
474 return
475 }
476
477 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
478 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
479 return
480 }
481 }
482 }
483
484
485
486
487
488 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
489 setParallel(t)
490 defer afterTest(t)
491
492 writeBackQuery := func(w ResponseWriter, r *Request) {
493 fmt.Fprintf(w, "%s", r.URL.RawQuery)
494 }
495
496 mux := NewServeMux()
497 mux.HandleFunc("/testOne", writeBackQuery)
498 mux.HandleFunc("/testTwo/", writeBackQuery)
499 mux.HandleFunc("/testThree", writeBackQuery)
500 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
501 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
502 })
503
504 ts := httptest.NewServer(mux)
505 defer ts.Close()
506
507 tests := [...]struct {
508 path string
509 method string
510 want string
511 statusOk bool
512 }{
513 0: {"/testOne?this=that", "GET", "this=that", true},
514 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
515 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
516 3: {"/testTwo?", "GET", "", true},
517 4: {"/testThree?foo", "GET", "foo", true},
518 5: {"/testThree/?foo", "GET", "foo:bar", true},
519 6: {"/testThree?foo", "CONNECT", "foo", true},
520 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
521
522
523 8: {"/testOne/foo/..?foo", "GET", "foo", true},
524 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
525 }
526
527 for i, tt := range tests {
528 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
529 res, err := ts.Client().Do(req)
530 if err != nil {
531 continue
532 }
533 slurp, _ := io.ReadAll(res.Body)
534 res.Body.Close()
535 if !tt.statusOk {
536 if got, want := res.StatusCode, 404; got != want {
537 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
538 }
539 }
540 if got, want := string(slurp), tt.want; got != want {
541 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
542 }
543 }
544 }
545
546 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
547 setParallel(t)
548 defer afterTest(t)
549
550 mux := NewServeMux()
551 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
552 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
553 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
554 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
555 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
556 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
557
558 tests := []struct {
559 method string
560 url string
561 code int
562 loc string
563 want string
564 }{
565 {"GET", "http://example.com/", 404, "", ""},
566 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
567 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
568 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
569 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
570 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
571 {"CONNECT", "http://example.com/", 404, "", ""},
572 {"CONNECT", "http://example.com:3000/", 404, "", ""},
573 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
574 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
575 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
576 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
577 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
578 }
579
580 ts := httptest.NewServer(mux)
581 defer ts.Close()
582
583 for i, tt := range tests {
584 req, _ := NewRequest(tt.method, tt.url, nil)
585 w := httptest.NewRecorder()
586 mux.ServeHTTP(w, req)
587
588 if got, want := w.Code, tt.code; got != want {
589 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
590 }
591
592 if tt.code == 301 {
593 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
594 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
595 }
596 } else {
597 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
598 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
599 }
600 }
601 }
602 }
603
604 func TestShouldRedirectConcurrency(t *testing.T) {
605 setParallel(t)
606 defer afterTest(t)
607
608 mux := NewServeMux()
609 ts := httptest.NewServer(mux)
610 defer ts.Close()
611 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
612 }
613
614 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
615 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
616 func benchmarkServeMux(b *testing.B, runHandler bool) {
617 type test struct {
618 path string
619 code int
620 req *Request
621 }
622
623
624 var tests []test
625 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
626 for _, e := range endpoints {
627 for i := 200; i < 230; i++ {
628 p := fmt.Sprintf("/%s/%d/", e, i)
629 tests = append(tests, test{
630 path: p,
631 code: i,
632 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
633 })
634 }
635 }
636 mux := NewServeMux()
637 for _, tt := range tests {
638 mux.Handle(tt.path, serve(tt.code))
639 }
640
641 rw := httptest.NewRecorder()
642 b.ReportAllocs()
643 b.ResetTimer()
644 for i := 0; i < b.N; i++ {
645 for _, tt := range tests {
646 *rw = httptest.ResponseRecorder{}
647 h, pattern := mux.Handler(tt.req)
648 if runHandler {
649 h.ServeHTTP(rw, tt.req)
650 if pattern != tt.path || rw.Code != tt.code {
651 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
652 }
653 }
654 }
655 }
656 }
657
658 func TestServerTimeouts(t *testing.T) {
659 setParallel(t)
660 defer afterTest(t)
661
662 tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
663 for i, timeout := range tries {
664 err := testServerTimeouts(timeout)
665 if err == nil {
666 return
667 }
668 t.Logf("failed at %v: %v", timeout, err)
669 if i != len(tries)-1 {
670 t.Logf("retrying at %v ...", tries[i+1])
671 }
672 }
673 t.Fatal("all attempts failed")
674 }
675
// testServerTimeouts runs one attempt with the server's ReadTimeout and
// WriteTimeout both set to timeout. It returns an error describing the
// first check that failed:
//  1. a normal GET succeeds and is the first request the handler sees;
//  2. a connection that never sends a request is closed by the server (the
//     client reads EOF), no sooner than 4/5 of timeout;
//  3. a following GET is answered as the second request, so the idle
//     connection never reached the handler;
//  4. outside -short mode, a keep-alive connection that sends five requests
//     timeout/2 apart, outliving timeout in total, keeps accepting writes.
func testServerTimeouts(timeout time.Duration) error {
	reqNum := 0
	ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
		reqNum++
		fmt.Fprintf(res, "req=%d", reqNum)
	}))
	ts.Config.ReadTimeout = timeout
	ts.Config.WriteTimeout = timeout
	ts.Start()
	defer ts.Close()

	// 1. A normal request succeeds.
	c := ts.Client()
	r, err := c.Get(ts.URL)
	if err != nil {
		return fmt.Errorf("http Get #1: %v", err)
	}
	got, err := io.ReadAll(r.Body)
	r.Body.Close()
	expected := "req=1"
	if string(got) != expected || err != nil {
		return fmt.Errorf("Unexpected response for request #1; got %q, %v; expected %q, nil",
			string(got), err, expected)
	}

	// 2. A client that connects but stays silent is disconnected, and not
	// much earlier than the read timeout.
	t1 := time.Now()
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		return fmt.Errorf("Dial: %v", err)
	}
	buf := make([]byte, 1)
	n, err := conn.Read(buf)
	conn.Close()
	latency := time.Since(t1)
	if n != 0 || err != io.EOF {
		return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
	}
	minLatency := timeout / 5 * 4
	if latency < minLatency {
		return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
	}

	// 3. The server still works, and the silent connection never ran the
	// handler (we must see "req=2", not "req=3").
	r, err = c.Get(ts.URL)
	if err != nil {
		return fmt.Errorf("http Get #2: %v", err)
	}
	got, err = io.ReadAll(r.Body)
	r.Body.Close()
	expected = "req=2"
	if string(got) != expected || err != nil {
		return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
	}

	// 4. A busy keep-alive connection whose total lifetime exceeds timeout.
	if !testing.Short() {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return fmt.Errorf("long Dial: %v", err)
		}
		defer conn.Close()
		go io.Copy(io.Discard, conn)
		for i := 0; i < 5; i++ {
			_, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
			if err != nil {
				return fmt.Errorf("on write %d: %v", i, err)
			}
			time.Sleep(timeout / 2)
		}
	}
	return nil
}
749
750
751 func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) {
752 if testing.Short() {
753 t.Skip("skipping in short mode")
754 }
755 setParallel(t)
756 defer afterTest(t)
757 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {}))
758 ts.Config.WriteTimeout = 250 * time.Millisecond
759 ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
760 ts.StartTLS()
761 defer ts.Close()
762
763 c := ts.Client()
764 if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
765 t.Fatal(err)
766 }
767
768 for i := 1; i <= 3; i++ {
769 req, err := NewRequest("GET", ts.URL, nil)
770 if err != nil {
771 t.Fatal(err)
772 }
773
774
775 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
776 defer cancel()
777 req = req.WithContext(ctx)
778
779 r, err := c.Do(req)
780 if ctx.Err() == context.DeadlineExceeded {
781 t.Fatalf("http2 Get #%d response timed out", i)
782 }
783 if err != nil {
784 t.Fatalf("http2 Get #%d: %v", i, err)
785 }
786 r.Body.Close()
787 if r.ProtoMajor != 2 {
788 t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto)
789 }
790 time.Sleep(ts.Config.WriteTimeout / 2)
791 }
792 }
793
794
795
// tryTimeouts calls testFunc with 250ms, then 500ms, then 1s, and stops at
// the first call that returns nil. Each failure is logged. If all three
// attempts fail, the test is failed with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	timeouts := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i, d := range timeouts {
		err := testFunc(d)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", d, err)
		if next := i + 1; next < len(timeouts) {
			t.Logf("retrying at %v ...", timeouts[next])
		}
	}
	t.Fatal("all attempts failed")
}
810
811
812 func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) {
813 if testing.Short() {
814 t.Skip("skipping in short mode")
815 }
816 setParallel(t)
817 defer afterTest(t)
818 tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream)
819 }
820
821 func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error {
822 reqNum := 0
823 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
824 reqNum++
825 if reqNum == 1 {
826 return
827 }
828 time.Sleep(timeout)
829 }))
830 ts.Config.WriteTimeout = timeout / 2
831 ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
832 ts.StartTLS()
833 defer ts.Close()
834
835 c := ts.Client()
836 if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
837 return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err)
838 }
839
840 req, err := NewRequest("GET", ts.URL, nil)
841 if err != nil {
842 return fmt.Errorf("NewRequest: %v", err)
843 }
844 r, err := c.Do(req)
845 if err != nil {
846 return fmt.Errorf("http2 Get #1: %v", err)
847 }
848 r.Body.Close()
849 if r.ProtoMajor != 2 {
850 return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
851 }
852
853 req, err = NewRequest("GET", ts.URL, nil)
854 if err != nil {
855 return fmt.Errorf("NewRequest: %v", err)
856 }
857 r, err = c.Do(req)
858 if err == nil {
859 r.Body.Close()
860 if r.ProtoMajor != 2 {
861 return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
862 }
863 return fmt.Errorf("http2 Get #2 expected error, got nil")
864 }
865 expected := "stream ID 3; INTERNAL_ERROR"
866 if !strings.Contains(err.Error(), expected) {
867 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
868 }
869 return nil
870 }
871
872
873 func TestHTTP2NoWriteDeadline(t *testing.T) {
874 if testing.Short() {
875 t.Skip("skipping in short mode")
876 }
877 setParallel(t)
878 defer afterTest(t)
879 tryTimeouts(t, testHTTP2NoWriteDeadline)
880 }
881
882 func testHTTP2NoWriteDeadline(timeout time.Duration) error {
883 reqNum := 0
884 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
885 reqNum++
886 if reqNum == 1 {
887 return
888 }
889 time.Sleep(timeout)
890 }))
891 ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
892 ts.StartTLS()
893 defer ts.Close()
894
895 c := ts.Client()
896 if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
897 return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err)
898 }
899
900 for i := 0; i < 2; i++ {
901 req, err := NewRequest("GET", ts.URL, nil)
902 if err != nil {
903 return fmt.Errorf("NewRequest: %v", err)
904 }
905 r, err := c.Do(req)
906 if err != nil {
907 return fmt.Errorf("http2 Get #%d: %v", i, err)
908 }
909 r.Body.Close()
910 if r.ProtoMajor != 2 {
911 return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
912 }
913 }
914 return nil
915 }
916
917
918
919
920 func TestOnlyWriteTimeout(t *testing.T) {
921 setParallel(t)
922 defer afterTest(t)
923 var (
924 mu sync.RWMutex
925 conn net.Conn
926 )
927 var afterTimeoutErrc = make(chan error, 1)
928 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) {
929 buf := make([]byte, 512<<10)
930 _, err := w.Write(buf)
931 if err != nil {
932 t.Errorf("handler Write error: %v", err)
933 return
934 }
935 mu.RLock()
936 defer mu.RUnlock()
937 if conn == nil {
938 t.Error("no established connection found")
939 return
940 }
941 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
942 _, err = w.Write(buf)
943 afterTimeoutErrc <- err
944 }))
945 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
946 ts.Start()
947 defer ts.Close()
948
949 c := ts.Client()
950
951 errc := make(chan error, 1)
952 go func() {
953 res, err := c.Get(ts.URL)
954 if err != nil {
955 errc <- err
956 return
957 }
958 _, err = io.Copy(io.Discard, res.Body)
959 res.Body.Close()
960 errc <- err
961 }()
962 select {
963 case err := <-errc:
964 if err == nil {
965 t.Errorf("expected an error from Get request")
966 }
967 case <-time.After(10 * time.Second):
968 t.Fatal("timeout waiting for Get error")
969 }
970 if err := <-afterTimeoutErrc; err == nil {
971 t.Error("expected write error after timeout")
972 }
973 }
974
975
// trackLastConnListener wraps a net.Listener. Each successfully accepted
// connection is stored into *last under mu, so the most recent connection
// can be read from elsewhere.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex // guards *last
	last *net.Conn     // receives every successfully accepted connection
}

// Accept delegates to the wrapped listener. On success it records the new
// connection in *l.last before returning it. Errors are passed through
// without touching *l.last.
func (l trackLastConnListener) Accept() (c net.Conn, err error) {
	c, err = l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	*l.last = c
	return c, nil
}
992
993
994 func TestIdentityResponse(t *testing.T) {
995 setParallel(t)
996 defer afterTest(t)
997 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
998 rw.Header().Set("Content-Length", "3")
999 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1000 switch {
1001 case req.FormValue("overwrite") == "1":
1002 _, err := rw.Write([]byte("foo TOO LONG"))
1003 if err != ErrContentLength {
1004 t.Errorf("expected ErrContentLength; got %v", err)
1005 }
1006 case req.FormValue("underwrite") == "1":
1007 rw.Header().Set("Content-Length", "500")
1008 rw.Write([]byte("too short"))
1009 default:
1010 rw.Write([]byte("foo"))
1011 }
1012 })
1013
1014 ts := httptest.NewServer(handler)
1015 defer ts.Close()
1016
1017 c := ts.Client()
1018
1019
1020
1021
1022
1023 for _, te := range []string{"", "identity"} {
1024 url := ts.URL + "/?te=" + te
1025 res, err := c.Get(url)
1026 if err != nil {
1027 t.Fatalf("error with Get of %s: %v", url, err)
1028 }
1029 if cl, expected := res.ContentLength, int64(3); cl != expected {
1030 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1031 }
1032 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1033 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1034 }
1035 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1036 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1037 url, expected, tl, res.TransferEncoding)
1038 }
1039 res.Body.Close()
1040 }
1041
1042
1043 url := ts.URL + "/?overwrite=1"
1044 res, err := c.Get(url)
1045 if err != nil {
1046 t.Fatalf("error with Get of %s: %v", url, err)
1047 }
1048 res.Body.Close()
1049
1050
1051
1052 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1053 if err != nil {
1054 t.Fatalf("error dialing: %v", err)
1055 }
1056 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1057 if err != nil {
1058 t.Fatalf("error writing: %v", err)
1059 }
1060
1061
1062 got, _ := io.ReadAll(conn)
1063 expectedSuffix := "\r\n\r\ntoo short"
1064 if !strings.HasSuffix(string(got), expectedSuffix) {
1065 t.Errorf("Expected output to end with %q; got response body %q",
1066 expectedSuffix, string(got))
1067 }
1068 }
1069
1070 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1071 setParallel(t)
1072 defer afterTest(t)
1073 s := httptest.NewServer(h)
1074 defer s.Close()
1075
1076 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1077 if err != nil {
1078 t.Fatal("dial error:", err)
1079 }
1080 defer conn.Close()
1081
1082 _, err = fmt.Fprint(conn, req)
1083 if err != nil {
1084 t.Fatal("print error:", err)
1085 }
1086
1087 r := bufio.NewReader(conn)
1088 res, err := ReadResponse(r, &Request{Method: "GET"})
1089 if err != nil {
1090 t.Fatal("ReadResponse error:", err)
1091 }
1092
1093 didReadAll := make(chan bool, 1)
1094 go func() {
1095 select {
1096 case <-time.After(5 * time.Second):
1097 t.Error("body not closed after 5s")
1098 return
1099 case <-didReadAll:
1100 }
1101 }()
1102
1103 _, err = io.ReadAll(r)
1104 if err != nil {
1105 t.Fatal("read error:", err)
1106 }
1107 didReadAll <- true
1108
1109 if !res.Close {
1110 t.Errorf("Response.Close = false; want true")
1111 }
1112 }
1113
1114 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1115 setParallel(t)
1116 defer afterTest(t)
1117 ts := httptest.NewServer(handler)
1118 defer ts.Close()
1119 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1120 if err != nil {
1121 t.Fatal(err)
1122 }
1123 defer conn.Close()
1124 br := bufio.NewReader(conn)
1125 for i := 0; i < 2; i++ {
1126 if _, err := io.WriteString(conn, req); err != nil {
1127 t.Fatal(err)
1128 }
1129 res, err := ReadResponse(br, nil)
1130 if err != nil {
1131 t.Fatalf("res %d: %v", i+1, err)
1132 }
1133 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1134 t.Fatalf("res %d body copy: %v", i+1, err)
1135 }
1136 res.Body.Close()
1137 }
1138 }
1139
1140
1141 func TestServeHTTP10Close(t *testing.T) {
1142 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1143 ServeFile(w, r, "testdata/file")
1144 }))
1145 }
1146
1147
1148 func TestClientCanClose(t *testing.T) {
1149 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1150
1151 }))
1152 }
1153
1154
1155
1156 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1157 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1158 w.Header().Set("Connection", "close")
1159 }))
1160 }
1161
1162 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1163 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1164 w.Header().Set("Connection", "close")
1165 }))
1166 }
1167
1168 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1169 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1170
1171
1172 }))
1173 }
1174
// send204 replies with status 204 (No Content) and no body.
func send204(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNoContent)
}

// send304 replies with status 304 (Not Modified) and no body.
func send304(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNotModified)
}
1177
1178
1179 func TestHTTP10KeepAlive204Response(t *testing.T) {
1180 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1181 }
1182
1183 func TestHTTP11KeepAlive204Response(t *testing.T) {
1184 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1185 }
1186
1187 func TestHTTP10KeepAlive304Response(t *testing.T) {
1188 testTCPConnectionStaysOpen(t,
1189 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1190 HandlerFunc(send304))
1191 }
1192
1193
1194 func TestKeepAliveFinalChunkWithEOF(t *testing.T) {
1195 setParallel(t)
1196 defer afterTest(t)
1197 cst := newClientServerTest(t, false , HandlerFunc(func(w ResponseWriter, r *Request) {
1198 w.(Flusher).Flush()
1199 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1200 }))
1201 defer cst.close()
1202 type data struct {
1203 Addr string
1204 }
1205 var addrs [2]data
1206 for i := range addrs {
1207 res, err := cst.c.Get(cst.ts.URL)
1208 if err != nil {
1209 t.Fatal(err)
1210 }
1211 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1212 t.Fatal(err)
1213 }
1214 if addrs[i].Addr == "" {
1215 t.Fatal("no address")
1216 }
1217 res.Body.Close()
1218 }
1219 if addrs[0] != addrs[1] {
1220 t.Fatalf("connection not reused")
1221 }
1222 }
1223
1224 func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) }
1225 func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) }
1226
1227 func testSetsRemoteAddr(t *testing.T, h2 bool) {
1228 setParallel(t)
1229 defer afterTest(t)
1230 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1231 fmt.Fprintf(w, "%s", r.RemoteAddr)
1232 }))
1233 defer cst.close()
1234
1235 res, err := cst.c.Get(cst.ts.URL)
1236 if err != nil {
1237 t.Fatalf("Get error: %v", err)
1238 }
1239 body, err := io.ReadAll(res.Body)
1240 if err != nil {
1241 t.Fatalf("ReadAll error: %v", err)
1242 }
1243 ip := string(body)
1244 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1245 t.Fatalf("Expected local addr; got %q", ip)
1246 }
1247 }
1248
// blockingRemoteAddrListener wraps a net.Listener. Each accepted connection is
// wrapped in a *blockingRemoteAddrConn and sent on conns before Accept returns.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn
}

// Accept accepts from the embedded listener and wraps the result. It then
// sends the wrapped conn on l.conns, blocking until the send can proceed.
// Errors from the embedded listener are returned with a nil conn.
func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
	inner, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	wrapped := &blockingRemoteAddrConn{Conn: inner, addrs: make(chan net.Addr, 1)}
	l.conns <- wrapped
	return wrapped, nil
}

// blockingRemoteAddrConn is a net.Conn whose RemoteAddr does not return until
// the test supplies an address on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr
}

// RemoteAddr waits for the next address sent on c.addrs and returns it.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1275
1276
1277 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1278 defer afterTest(t)
1279 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1280 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1281 }))
1282 conns := make(chan net.Conn)
1283 ts.Listener = &blockingRemoteAddrListener{
1284 Listener: ts.Listener,
1285 conns: conns,
1286 }
1287 ts.Start()
1288 defer ts.Close()
1289
1290 c := ts.Client()
1291 c.Timeout = time.Second
1292
1293 c.Transport.(*Transport).DisableKeepAlives = true
1294
1295 fetch := func(num int, response chan<- string) {
1296 resp, err := c.Get(ts.URL)
1297 if err != nil {
1298 t.Errorf("Request %d: %v", num, err)
1299 response <- ""
1300 return
1301 }
1302 defer resp.Body.Close()
1303 body, err := io.ReadAll(resp.Body)
1304 if err != nil {
1305 t.Errorf("Request %d: %v", num, err)
1306 response <- ""
1307 return
1308 }
1309 response <- string(body)
1310 }
1311
1312
1313 response1c := make(chan string, 1)
1314 go fetch(1, response1c)
1315
1316
1317 conn1 := <-conns
1318
1319
1320 response2c := make(chan string, 1)
1321 go fetch(2, response2c)
1322 var conn2 net.Conn
1323
1324 select {
1325 case conn2 = <-conns:
1326 case <-time.After(time.Second):
1327 t.Fatal("Second Accept didn't happen")
1328 }
1329
1330
1331 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1332 IP: net.ParseIP("12.12.12.12"), Port: 12}
1333
1334
1335 response2 := <-response2c
1336 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1337 t.Fatalf("response 2 addr = %q; want %q", g, e)
1338 }
1339
1340
1341 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1342 IP: net.ParseIP("21.21.21.21"), Port: 21}
1343
1344
1345 response1 := <-response1c
1346 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1347 t.Fatalf("response 1 addr = %q; want %q", g, e)
1348 }
1349 }
1350
1351
1352
1353 func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) }
1354 func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) }
1355
1356 func testHeadResponses(t *testing.T, h2 bool) {
1357 setParallel(t)
1358 defer afterTest(t)
1359 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1360 _, err := w.Write([]byte("<html>"))
1361 if err != nil {
1362 t.Errorf("ResponseWriter.Write: %v", err)
1363 }
1364
1365
1366 _, err = io.Copy(w, strings.NewReader("789a"))
1367 if err != nil {
1368 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1369 }
1370 }))
1371 defer cst.close()
1372 res, err := cst.c.Head(cst.ts.URL)
1373 if err != nil {
1374 t.Error(err)
1375 }
1376 if len(res.TransferEncoding) > 0 {
1377 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1378 }
1379 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1380 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1381 }
1382 if v := res.ContentLength; v != 10 {
1383 t.Errorf("Content-Length: %d; want 10", v)
1384 }
1385 body, err := io.ReadAll(res.Body)
1386 if err != nil {
1387 t.Error(err)
1388 }
1389 if len(body) > 0 {
1390 t.Errorf("got unexpected body %q", string(body))
1391 }
1392 }
1393
1394 func TestTLSHandshakeTimeout(t *testing.T) {
1395 setParallel(t)
1396 defer afterTest(t)
1397 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
1398 errc := make(chanWriter, 10)
1399 ts.Config.ReadTimeout = 250 * time.Millisecond
1400 ts.Config.ErrorLog = log.New(errc, "", 0)
1401 ts.StartTLS()
1402 defer ts.Close()
1403 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1404 if err != nil {
1405 t.Fatalf("Dial: %v", err)
1406 }
1407 defer conn.Close()
1408
1409 var buf [1]byte
1410 n, err := conn.Read(buf[:])
1411 if err == nil || n != 0 {
1412 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1413 }
1414
1415 select {
1416 case v := <-errc:
1417 if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1418 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1419 }
1420 case <-time.After(5 * time.Second):
1421 t.Errorf("timeout waiting for logged error")
1422 }
1423 }
1424
1425 func TestTLSServer(t *testing.T) {
1426 setParallel(t)
1427 defer afterTest(t)
1428 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1429 if r.TLS != nil {
1430 w.Header().Set("X-TLS-Set", "true")
1431 if r.TLS.HandshakeComplete {
1432 w.Header().Set("X-TLS-HandshakeComplete", "true")
1433 }
1434 }
1435 }))
1436 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1437 defer ts.Close()
1438
1439
1440
1441
1442
1443
1444 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1445 if err != nil {
1446 t.Fatalf("Dial: %v", err)
1447 }
1448 defer idleConn.Close()
1449
1450 if !strings.HasPrefix(ts.URL, "https://") {
1451 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1452 return
1453 }
1454 client := ts.Client()
1455 res, err := client.Get(ts.URL)
1456 if err != nil {
1457 t.Error(err)
1458 return
1459 }
1460 if res == nil {
1461 t.Errorf("got nil Response")
1462 return
1463 }
1464 defer res.Body.Close()
1465 if res.Header.Get("X-TLS-Set") != "true" {
1466 t.Errorf("expected X-TLS-Set response header")
1467 return
1468 }
1469 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1470 t.Errorf("expected X-TLS-HandshakeComplete header")
1471 }
1472 }
1473
1474 func TestServeTLS(t *testing.T) {
1475 CondSkipHTTP2(t)
1476
1477 defer afterTest(t)
1478 defer SetTestHookServerServe(nil)
1479
1480 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1481 if err != nil {
1482 t.Fatal(err)
1483 }
1484 tlsConf := &tls.Config{
1485 Certificates: []tls.Certificate{cert},
1486 }
1487
1488 ln := newLocalListener(t)
1489 defer ln.Close()
1490 addr := ln.Addr().String()
1491
1492 serving := make(chan bool, 1)
1493 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1494 serving <- true
1495 })
1496 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1497 s := &Server{
1498 Addr: addr,
1499 TLSConfig: tlsConf,
1500 Handler: handler,
1501 }
1502 errc := make(chan error, 1)
1503 go func() { errc <- s.ServeTLS(ln, "", "") }()
1504 select {
1505 case err := <-errc:
1506 t.Fatalf("ServeTLS: %v", err)
1507 case <-serving:
1508 case <-time.After(5 * time.Second):
1509 t.Fatal("timeout")
1510 }
1511
1512 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1513 InsecureSkipVerify: true,
1514 NextProtos: []string{"h2", "http/1.1"},
1515 })
1516 if err != nil {
1517 t.Fatal(err)
1518 }
1519 defer c.Close()
1520 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1521 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1522 }
1523 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1524 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1525 }
1526 }
1527
1528
1529 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1530 setParallel(t)
1531 defer afterTest(t)
1532 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1533 t.Error("unexpected HTTPS request")
1534 }))
1535 var errBuf bytes.Buffer
1536 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1537 defer ts.Close()
1538 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1539 if err != nil {
1540 t.Fatal(err)
1541 }
1542 defer conn.Close()
1543 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1544 slurp, err := io.ReadAll(conn)
1545 if err != nil {
1546 t.Fatal(err)
1547 }
1548 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1549 if !strings.HasPrefix(string(slurp), wantPrefix) {
1550 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1551 }
1552 }
1553
1554
1555 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1556 testAutomaticHTTP2_Serve(t, nil, true)
1557 }
1558
1559 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1560 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1561 }
1562
1563 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1564 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1565 }
1566
1567 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1568 setParallel(t)
1569 defer afterTest(t)
1570 ln := newLocalListener(t)
1571 ln.Close()
1572 var s Server
1573 s.TLSConfig = tlsConf
1574 if err := s.Serve(ln); err == nil {
1575 t.Fatal("expected an error")
1576 }
1577 gotH2 := s.TLSNextProto["h2"] != nil
1578 if gotH2 != wantH2 {
1579 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1580 }
1581 }
1582
1583 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1584 setParallel(t)
1585 defer afterTest(t)
1586 ln := newLocalListener(t)
1587 ln.Close()
1588 var s Server
1589
1590
1591 s.TLSConfig = &tls.Config{
1592 NextProtos: []string{"h2"},
1593 }
1594 if err := s.Serve(ln); err == nil {
1595 t.Fatal("expected an error")
1596 }
1597 on := s.TLSNextProto["h2"] != nil
1598 if !on {
1599 t.Errorf("http2 wasn't automatically enabled")
1600 }
1601 }
1602
1603 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1604 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1605 if err != nil {
1606 t.Fatal(err)
1607 }
1608 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1609 Certificates: []tls.Certificate{cert},
1610 })
1611 }
1612
1613 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1614 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1615 if err != nil {
1616 t.Fatal(err)
1617 }
1618 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1619 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1620 return &cert, nil
1621 },
1622 })
1623 }
1624
1625 func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
1626 CondSkipHTTP2(t)
1627
1628 defer afterTest(t)
1629 defer SetTestHookServerServe(nil)
1630 var ok bool
1631 var s *Server
1632 const maxTries = 5
1633 var ln net.Listener
1634 Try:
1635 for try := 0; try < maxTries; try++ {
1636 ln = newLocalListener(t)
1637 addr := ln.Addr().String()
1638 ln.Close()
1639 t.Logf("Got %v", addr)
1640 lnc := make(chan net.Listener, 1)
1641 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1642 lnc <- ln
1643 })
1644 s = &Server{
1645 Addr: addr,
1646 TLSConfig: tlsConf,
1647 }
1648 errc := make(chan error, 1)
1649 go func() { errc <- s.ListenAndServeTLS("", "") }()
1650 select {
1651 case err := <-errc:
1652 t.Logf("On try #%v: %v", try+1, err)
1653 continue
1654 case ln = <-lnc:
1655 ok = true
1656 t.Logf("Listening on %v", ln.Addr().String())
1657 break Try
1658 }
1659 }
1660 if !ok {
1661 t.Fatalf("Failed to start up after %d tries", maxTries)
1662 }
1663 defer ln.Close()
1664 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1665 InsecureSkipVerify: true,
1666 NextProtos: []string{"h2", "http/1.1"},
1667 })
1668 if err != nil {
1669 t.Fatal(err)
1670 }
1671 defer c.Close()
1672 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1673 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1674 }
1675 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1676 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1677 }
1678 }
1679
// serverExpectTest is one case for TestServerExpect.
type serverExpectTest struct {
	contentLength    int    // request body length (used when !chunked)
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the Expect request header, e.g. "100-continue"
	readBody         bool   // whether the handler should read the body (otherwise it replies 401)
	expectedResponse string // substring expected in the first line of the response
}

// expectTest builds a non-chunked serverExpectTest from its other four fields.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tc serverExpectTest
	tc.contentLength = contentLength
	tc.expectation = expectation
	tc.readBody = readBody
	tc.expectedResponse = expectedResponse
	return tc
}
1696
// serverExpectTests are the cases run by TestServerExpect. A readBody of false
// means the handler replies 401 without reading the body.
var serverExpectTests = []serverExpectTest{
	// Normal 100-continue; the Expect value is expected to match case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// No Expect header: the client sends the body up front and the handler reads it.
	expectTest(100, "", true, "200 OK"),

	// 100-continue, but the handler rejects the request without reading the
	// body, so the first status line is the rejection, not "100 Continue".
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// Likewise without an Expect header.
	expectTest(100, "", false, "401 Unauthorized"),

	// An unrecognized expectation is rejected with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// 100-continue with an empty body (Content-Length: 0) is expected to succeed.
	expectTest(0, "100-continue", true, "200 OK"),
	// 100-continue with an empty body, and the handler doesn't read it.
	expectTest(0, "100-continue", false, "401 Unauthorized"),
	// 100-continue with a chunked body that the handler reads.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1726
1727
1728
1729
1730 func TestServerExpect(t *testing.T) {
1731 setParallel(t)
1732 defer afterTest(t)
1733 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1734
1735
1736
1737 if strings.Contains(r.URL.RawQuery, "readbody=true") {
1738 io.ReadAll(r.Body)
1739 w.Write([]byte("Hi"))
1740 } else {
1741 w.WriteHeader(StatusUnauthorized)
1742 }
1743 }))
1744 defer ts.Close()
1745
1746 runTest := func(test serverExpectTest) {
1747 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1748 if err != nil {
1749 t.Fatalf("Dial: %v", err)
1750 }
1751 defer conn.Close()
1752
1753
1754
1755 writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"
1756
1757 go func() {
1758 contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
1759 if test.chunked {
1760 contentLen = "Transfer-Encoding: chunked"
1761 }
1762 _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
1763 "Connection: close\r\n"+
1764 "%s\r\n"+
1765 "Expect: %s\r\nHost: foo\r\n\r\n",
1766 test.readBody, contentLen, test.expectation)
1767 if err != nil {
1768 t.Errorf("On test %#v, error writing request headers: %v", test, err)
1769 return
1770 }
1771 if writeBody {
1772 var targ io.WriteCloser = struct {
1773 io.Writer
1774 io.Closer
1775 }{
1776 conn,
1777 io.NopCloser(nil),
1778 }
1779 if test.chunked {
1780 targ = httputil.NewChunkedWriter(conn)
1781 }
1782 body := strings.Repeat("A", test.contentLength)
1783 _, err = fmt.Fprint(targ, body)
1784 if err == nil {
1785 err = targ.Close()
1786 }
1787 if err != nil {
1788 if !test.readBody {
1789
1790
1791 t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
1792 return
1793 }
1794 t.Errorf("On test %#v, error writing request body: %v", test, err)
1795 }
1796 }
1797 }()
1798 bufr := bufio.NewReader(conn)
1799 line, err := bufr.ReadString('\n')
1800 if err != nil {
1801 if writeBody && !test.readBody {
1802
1803
1804
1805
1806
1807 t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
1808 return
1809 }
1810 t.Fatalf("On test %#v, ReadString: %v", test, err)
1811 }
1812 if !strings.Contains(line, test.expectedResponse) {
1813 t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
1814 }
1815 }
1816
1817 for _, test := range serverExpectTests {
1818 runTest(test)
1819 }
1820 }
1821
1822
1823
1824 func TestServerUnreadRequestBodyLittle(t *testing.T) {
1825 setParallel(t)
1826 defer afterTest(t)
1827 conn := new(testConn)
1828 body := strings.Repeat("x", 100<<10)
1829 conn.readBuf.Write([]byte(fmt.Sprintf(
1830 "POST / HTTP/1.1\r\n"+
1831 "Host: test\r\n"+
1832 "Content-Length: %d\r\n"+
1833 "\r\n", len(body))))
1834 conn.readBuf.Write([]byte(body))
1835
1836 done := make(chan bool)
1837
1838 readBufLen := func() int {
1839 conn.readMu.Lock()
1840 defer conn.readMu.Unlock()
1841 return conn.readBuf.Len()
1842 }
1843
1844 ls := &oneConnListener{conn}
1845 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
1846 defer close(done)
1847 if bufLen := readBufLen(); bufLen < len(body)/2 {
1848 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
1849 }
1850 rw.WriteHeader(200)
1851 rw.(Flusher).Flush()
1852 if g, e := readBufLen(), 0; g != e {
1853 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
1854 }
1855 if c := rw.Header().Get("Connection"); c != "" {
1856 t.Errorf(`Connection header = %q; want ""`, c)
1857 }
1858 }))
1859 <-done
1860 }
1861
1862
1863
1864
1865 func TestServerUnreadRequestBodyLarge(t *testing.T) {
1866 setParallel(t)
1867 if testing.Short() && testenv.Builder() == "" {
1868 t.Log("skipping in short mode")
1869 }
1870 conn := new(testConn)
1871 body := strings.Repeat("x", 1<<20)
1872 conn.readBuf.Write([]byte(fmt.Sprintf(
1873 "POST / HTTP/1.1\r\n"+
1874 "Host: test\r\n"+
1875 "Content-Length: %d\r\n"+
1876 "\r\n", len(body))))
1877 conn.readBuf.Write([]byte(body))
1878 conn.closec = make(chan bool, 1)
1879
1880 ls := &oneConnListener{conn}
1881 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
1882 if conn.readBuf.Len() < len(body)/2 {
1883 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
1884 }
1885 rw.WriteHeader(200)
1886 rw.(Flusher).Flush()
1887 if conn.readBuf.Len() < len(body)/2 {
1888 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
1889 }
1890 }))
1891 <-conn.closec
1892
1893 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
1894 t.Errorf("Expected a Connection: close header; got response: %s", res)
1895 }
1896 }
1897
// handlerBodyCloseTest is one case for TestHandlerBodyClose. It describes the
// first request's body and Connection header, and what the test expects when
// the handler closes that body.
type handlerBodyCloseTest struct {
	bodySize     int  // bytes of request body
	bodyChunked  bool // chunked transfer encoding instead of Content-Length
	reqConnClose bool // request sends "Connection: close" (and no second request follows)

	wantEOFSearch bool // Body.Close is expected to consume some input
	wantNextReq   bool // the pipelined second request is required to be served
}

// connectionHeader returns the CRLF-terminated "Connection: close" header line
// when t.reqConnClose is set, or "" otherwise.
func (t handlerBodyCloseTest) connectionHeader() string {
	const closeLine = "Connection: close\r\n"
	if !t.reqConnClose {
		return ""
	}
	return closeLine
}
1913
// handlerBodyCloseTests are the cases run by TestHandlerBodyClose. When
// reqConnClose is false, a second request is pipelined after the body.
// wantNextReq=false means that request is not required to be served.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small (20 KB) body with Content-Length: closing the body is expected to
	// read past it, and the pipelined next request must be served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small (20 KB) chunked body: same expectations as case 0.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small body with Content-Length, but the request says
	// "Connection: close": there is no next request, and Close is expected
	// not to read ahead.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with "Connection: close": Close is still expected to
	// read ahead (presumably because a chunked body may carry trailers), but
	// there is no next request.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big (1 MB) body with Content-Length: Close is expected not to read at
	// all (presumably the size alone rules out skipping it).
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Big chunked body: Close is expected to read some input, though serving
	// the next request is not required.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big chunked body with "Connection: close": Close is still expected to
	// read (presumably searching for trailers); no next request.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big body with Content-Length and "Connection: close": Close is expected
	// to do no reads; no next request.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
1998
1999 func TestHandlerBodyClose(t *testing.T) {
2000 setParallel(t)
2001 if testing.Short() && testenv.Builder() == "" {
2002 t.Skip("skipping in -short mode")
2003 }
2004 for i, tt := range handlerBodyCloseTests {
2005 testHandlerBodyClose(t, i, tt)
2006 }
2007 }
2008
2009 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2010 conn := new(testConn)
2011 body := strings.Repeat("x", tt.bodySize)
2012 if tt.bodyChunked {
2013 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2014 "Host: test\r\n" +
2015 tt.connectionHeader() +
2016 "Transfer-Encoding: chunked\r\n" +
2017 "\r\n")
2018 cw := internal.NewChunkedWriter(&conn.readBuf)
2019 io.WriteString(cw, body)
2020 cw.Close()
2021 conn.readBuf.WriteString("\r\n")
2022 } else {
2023 conn.readBuf.Write([]byte(fmt.Sprintf(
2024 "POST / HTTP/1.1\r\n"+
2025 "Host: test\r\n"+
2026 tt.connectionHeader()+
2027 "Content-Length: %d\r\n"+
2028 "\r\n", len(body))))
2029 conn.readBuf.Write([]byte(body))
2030 }
2031 if !tt.reqConnClose {
2032 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2033 }
2034 conn.closec = make(chan bool, 1)
2035
2036 readBufLen := func() int {
2037 conn.readMu.Lock()
2038 defer conn.readMu.Unlock()
2039 return conn.readBuf.Len()
2040 }
2041
2042 ls := &oneConnListener{conn}
2043 var numReqs int
2044 var size0, size1 int
2045 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2046 numReqs++
2047 if numReqs == 1 {
2048 size0 = readBufLen()
2049 req.Body.Close()
2050 size1 = readBufLen()
2051 }
2052 }))
2053 <-conn.closec
2054 if numReqs < 1 || numReqs > 2 {
2055 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2056 }
2057 didSearch := size0 != size1
2058 if didSearch != tt.wantEOFSearch {
2059 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2060 }
2061 if tt.wantNextReq && numReqs != 2 {
2062 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2063 }
2064 }
2065
2066
2067
// testHandlerBodyConsumer is a named strategy a test handler applies to a
// request body.
type testHandlerBodyConsumer struct {
	name string
	f    func(io.ReadCloser)
}

// testHandlerBodyConsumers lists the strategies in use: leave the body alone,
// close it, or read it to EOF and discard the bytes.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{name: "nil", f: func(io.ReadCloser) {}},
	{name: "close", f: func(body io.ReadCloser) { body.Close() }},
	{name: "discard", f: func(body io.ReadCloser) { io.Copy(io.Discard, body) }},
}
2078
2079 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2080 setParallel(t)
2081 defer afterTest(t)
2082 for _, handler := range testHandlerBodyConsumers {
2083 conn := new(testConn)
2084 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2085 "Host: test\r\n" +
2086 "Transfer-Encoding: chunked\r\n" +
2087 "\r\n" +
2088 "hax\r\n" +
2089 "GET /secret HTTP/1.1\r\n" +
2090 "Host: test\r\n" +
2091 "\r\n")
2092
2093 conn.closec = make(chan bool, 1)
2094 ls := &oneConnListener{conn}
2095 var numReqs int
2096 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2097 numReqs++
2098 if strings.Contains(req.URL.Path, "secret") {
2099 t.Error("Request for /secret encountered, should not have happened.")
2100 }
2101 handler.f(req.Body)
2102 }))
2103 <-conn.closec
2104 if numReqs != 1 {
2105 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2106 }
2107 }
2108 }
2109
2110 func TestInvalidTrailerClosesConnection(t *testing.T) {
2111 setParallel(t)
2112 defer afterTest(t)
2113 for _, handler := range testHandlerBodyConsumers {
2114 conn := new(testConn)
2115 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2116 "Host: test\r\n" +
2117 "Trailer: hack\r\n" +
2118 "Transfer-Encoding: chunked\r\n" +
2119 "\r\n" +
2120 "3\r\n" +
2121 "hax\r\n" +
2122 "0\r\n" +
2123 "I'm not a valid trailer\r\n" +
2124 "GET /secret HTTP/1.1\r\n" +
2125 "Host: test\r\n" +
2126 "\r\n")
2127
2128 conn.closec = make(chan bool, 1)
2129 ln := &oneConnListener{conn}
2130 var numReqs int
2131 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2132 numReqs++
2133 if strings.Contains(req.URL.Path, "secret") {
2134 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2135 }
2136 handler.f(req.Body)
2137 }))
2138 <-conn.closec
2139 if numReqs != 1 {
2140 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2141 }
2142 }
2143 }
2144
2145
2146
2147
2148 type slowTestConn struct {
2149
2150 script []any
2151 closec chan bool
2152
2153 mu sync.Mutex
2154 rd, wd time.Time
2155 noopConn
2156 }
2157
2158 func (c *slowTestConn) SetDeadline(t time.Time) error {
2159 c.SetReadDeadline(t)
2160 c.SetWriteDeadline(t)
2161 return nil
2162 }
2163
2164 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2165 c.mu.Lock()
2166 defer c.mu.Unlock()
2167 c.rd = t
2168 return nil
2169 }
2170
2171 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2172 c.mu.Lock()
2173 defer c.mu.Unlock()
2174 c.wd = t
2175 return nil
2176 }
2177
2178 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2179 c.mu.Lock()
2180 defer c.mu.Unlock()
2181 restart:
2182 if !c.rd.IsZero() && time.Now().After(c.rd) {
2183 return 0, syscall.ETIMEDOUT
2184 }
2185 if len(c.script) == 0 {
2186 return 0, io.EOF
2187 }
2188
2189 switch cue := c.script[0].(type) {
2190 case time.Duration:
2191 if !c.rd.IsZero() {
2192
2193
2194 if remaining := time.Until(c.rd); remaining < cue {
2195 c.script[0] = cue - remaining
2196 time.Sleep(remaining)
2197 return 0, syscall.ETIMEDOUT
2198 }
2199 }
2200 c.script = c.script[1:]
2201 time.Sleep(cue)
2202 goto restart
2203
2204 case string:
2205 n = copy(b, cue)
2206
2207 if len(cue) > n {
2208 c.script[0] = cue[n:]
2209 } else {
2210 c.script = c.script[1:]
2211 }
2212
2213 default:
2214 panic("unknown cue in slowTestConn script")
2215 }
2216
2217 return
2218 }
2219
2220 func (c *slowTestConn) Close() error {
2221 select {
2222 case c.closec <- true:
2223 default:
2224 }
2225 return nil
2226 }
2227
2228 func (c *slowTestConn) Write(b []byte) (int, error) {
2229 if !c.wd.IsZero() && time.Now().After(c.wd) {
2230 return 0, syscall.ETIMEDOUT
2231 }
2232 return len(b), nil
2233 }
2234
2235 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2236 if testing.Short() {
2237 t.Skip("skipping in -short mode")
2238 }
2239 defer afterTest(t)
2240 for _, handler := range testHandlerBodyConsumers {
2241 conn := &slowTestConn{
2242 script: []any{
2243 "POST /public HTTP/1.1\r\n" +
2244 "Host: test\r\n" +
2245 "Content-Length: 10000\r\n" +
2246 "\r\n",
2247 "foo bar baz",
2248 600 * time.Millisecond,
2249 "GET /secret HTTP/1.1\r\n" +
2250 "Host: test\r\n" +
2251 "\r\n",
2252 },
2253 closec: make(chan bool, 1),
2254 }
2255 ls := &oneConnListener{conn}
2256
2257 var numReqs int
2258 s := Server{
2259 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2260 numReqs++
2261 if strings.Contains(req.URL.Path, "secret") {
2262 t.Error("Request for /secret encountered, should not have happened.")
2263 }
2264 handler.f(req.Body)
2265 }),
2266 ReadTimeout: 400 * time.Millisecond,
2267 }
2268 go s.Serve(ls)
2269 <-conn.closec
2270
2271 if numReqs != 1 {
2272 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2273 }
2274 }
2275 }
2276
2277
// cancelableTimeoutContext wraps a Context and reports any error from it as
// context.DeadlineExceeded. Canceling the wrapped context therefore looks
// like a timeout to code that checks Err.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns nil while the embedded context has no error, and
// context.DeadlineExceeded once it does.
func (c cancelableTimeoutContext) Err() error {
	if err := c.Context.Err(); err == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2288
2289 func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
2290 func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
2291 func testTimeoutHandler(t *testing.T, h2 bool) {
2292 setParallel(t)
2293 defer afterTest(t)
2294 sendHi := make(chan bool, 1)
2295 writeErrors := make(chan error, 1)
2296 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2297 <-sendHi
2298 _, werr := w.Write([]byte("hi"))
2299 writeErrors <- werr
2300 })
2301 ctx, cancel := context.WithCancel(context.Background())
2302 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2303 cst := newClientServerTest(t, h2, h)
2304 defer cst.close()
2305
2306
2307 sendHi <- true
2308 res, err := cst.c.Get(cst.ts.URL)
2309 if err != nil {
2310 t.Error(err)
2311 }
2312 if g, e := res.StatusCode, StatusOK; g != e {
2313 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2314 }
2315 body, _ := io.ReadAll(res.Body)
2316 if g, e := string(body), "hi"; g != e {
2317 t.Errorf("got body %q; expected %q", g, e)
2318 }
2319 if g := <-writeErrors; g != nil {
2320 t.Errorf("got unexpected Write error on first request: %v", g)
2321 }
2322
2323
2324 cancel()
2325
2326 res, err = cst.c.Get(cst.ts.URL)
2327 if err != nil {
2328 t.Error(err)
2329 }
2330 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2331 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2332 }
2333 body, _ = io.ReadAll(res.Body)
2334 if !strings.Contains(string(body), "<title>Timeout</title>") {
2335 t.Errorf("expected timeout body; got %q", string(body))
2336 }
2337 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2338 t.Errorf("response content-type = %q; want %q", g, w)
2339 }
2340
2341
2342
2343 sendHi <- true
2344 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2345 t.Errorf("expected Write error of %v; got %v", e, g)
2346 }
2347 }
2348
2349
2350 func TestTimeoutHandlerRace(t *testing.T) {
2351 setParallel(t)
2352 defer afterTest(t)
2353
2354 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2355 ms, _ := strconv.Atoi(r.URL.Path[1:])
2356 if ms == 0 {
2357 ms = 1
2358 }
2359 for i := 0; i < ms; i++ {
2360 w.Write([]byte("hi"))
2361 time.Sleep(time.Millisecond)
2362 }
2363 })
2364
2365 ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
2366 defer ts.Close()
2367
2368 c := ts.Client()
2369
2370 var wg sync.WaitGroup
2371 gate := make(chan bool, 10)
2372 n := 50
2373 if testing.Short() {
2374 n = 10
2375 gate = make(chan bool, 3)
2376 }
2377 for i := 0; i < n; i++ {
2378 gate <- true
2379 wg.Add(1)
2380 go func() {
2381 defer wg.Done()
2382 defer func() { <-gate }()
2383 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2384 if err == nil {
2385 io.Copy(io.Discard, res.Body)
2386 res.Body.Close()
2387 }
2388 }()
2389 }
2390 wg.Wait()
2391 }
2392
2393
2394
2395 func TestTimeoutHandlerRaceHeader(t *testing.T) {
2396 setParallel(t)
2397 defer afterTest(t)
2398
2399 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2400 w.WriteHeader(204)
2401 })
2402
2403 ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, ""))
2404 defer ts.Close()
2405
2406 var wg sync.WaitGroup
2407 gate := make(chan bool, 50)
2408 n := 500
2409 if testing.Short() {
2410 n = 10
2411 }
2412
2413 c := ts.Client()
2414 for i := 0; i < n; i++ {
2415 gate <- true
2416 wg.Add(1)
2417 go func() {
2418 defer wg.Done()
2419 defer func() { <-gate }()
2420 res, err := c.Get(ts.URL)
2421 if err != nil {
2422
2423
2424 t.Log(err)
2425 return
2426 }
2427 defer res.Body.Close()
2428 io.Copy(io.Discard, res.Body)
2429 }()
2430 }
2431 wg.Wait()
2432 }
2433
2434
2435 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
2436 setParallel(t)
2437 defer afterTest(t)
2438 sendHi := make(chan bool, 1)
2439 writeErrors := make(chan error, 1)
2440 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2441 w.Header().Set("Content-Type", "text/plain")
2442 <-sendHi
2443 _, werr := w.Write([]byte("hi"))
2444 writeErrors <- werr
2445 })
2446 ctx, cancel := context.WithCancel(context.Background())
2447 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2448 cst := newClientServerTest(t, h1Mode, h)
2449 defer cst.close()
2450
2451
2452 sendHi <- true
2453 res, err := cst.c.Get(cst.ts.URL)
2454 if err != nil {
2455 t.Error(err)
2456 }
2457 if g, e := res.StatusCode, StatusOK; g != e {
2458 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2459 }
2460 body, _ := io.ReadAll(res.Body)
2461 if g, e := string(body), "hi"; g != e {
2462 t.Errorf("got body %q; expected %q", g, e)
2463 }
2464 if g := <-writeErrors; g != nil {
2465 t.Errorf("got unexpected Write error on first request: %v", g)
2466 }
2467
2468
2469 cancel()
2470
2471 res, err = cst.c.Get(cst.ts.URL)
2472 if err != nil {
2473 t.Error(err)
2474 }
2475 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2476 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2477 }
2478 body, _ = io.ReadAll(res.Body)
2479 if !strings.Contains(string(body), "<title>Timeout</title>") {
2480 t.Errorf("expected timeout body; got %q", string(body))
2481 }
2482
2483
2484
2485 sendHi <- true
2486 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2487 t.Errorf("expected Write error of %v; got %v", e, g)
2488 }
2489 }
2490
2491
2492 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2493 if testing.Short() {
2494 t.Skip("skipping sleeping test in -short mode")
2495 }
2496 defer afterTest(t)
2497 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2498 w.WriteHeader(StatusNoContent)
2499 }
2500 timeout := 300 * time.Millisecond
2501 ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
2502 defer ts.Close()
2503
2504 c := ts.Client()
2505
2506
2507
2508
2509 time.Sleep(2 * timeout)
2510 res, err := c.Get(ts.URL)
2511 if err != nil {
2512 t.Fatal(err)
2513 }
2514 defer res.Body.Close()
2515 if res.StatusCode != StatusNoContent {
2516 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2517 }
2518 }
2519
2520 func TestTimeoutHandlerContextCanceled(t *testing.T) {
2521 setParallel(t)
2522 defer afterTest(t)
2523 writeErrors := make(chan error, 1)
2524 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2525 w.Header().Set("Content-Type", "text/plain")
2526 var err error
2527
2528
2529
2530 for i := 0; i < 100; i++ {
2531 _, err = w.Write([]byte("a"))
2532 if err != nil {
2533 break
2534 }
2535 time.Sleep(1 * time.Millisecond)
2536 }
2537 writeErrors <- err
2538 })
2539 ctx, cancel := context.WithCancel(context.Background())
2540 cancel()
2541 h := NewTestTimeoutHandler(sayHi, ctx)
2542 cst := newClientServerTest(t, h1Mode, h)
2543 defer cst.close()
2544
2545 res, err := cst.c.Get(cst.ts.URL)
2546 if err != nil {
2547 t.Error(err)
2548 }
2549 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2550 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2551 }
2552 body, _ := io.ReadAll(res.Body)
2553 if g, e := string(body), ""; g != e {
2554 t.Errorf("got body %q; expected %q", g, e)
2555 }
2556 if g, e := <-writeErrors, context.Canceled; g != e {
2557 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2558 }
2559 }
2560
2561
2562 func TestTimeoutHandlerEmptyResponse(t *testing.T) {
2563 setParallel(t)
2564 defer afterTest(t)
2565 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2566
2567 }
2568 timeout := 300 * time.Millisecond
2569 ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
2570 defer ts.Close()
2571
2572 c := ts.Client()
2573
2574 res, err := c.Get(ts.URL)
2575 if err != nil {
2576 t.Fatal(err)
2577 }
2578 defer res.Body.Close()
2579 if res.StatusCode != StatusOK {
2580 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2581 }
2582 }
2583
2584
2585 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2586 wrapper := func(h Handler) Handler {
2587 return TimeoutHandler(h, time.Second, "")
2588 }
2589 testHandlerPanic(t, false, false, wrapper, "intentional death for testing")
2590 }
2591
// TestRedirectBadPath calls Redirect with an empty target on a request whose
// URL path has no leading slash. It checks that the requested status code is
// still recorded.
func TestRedirectBadPath(t *testing.T) {
	u := &url.URL{
		Scheme: "http",
		Path:   "not-empty-but-no-leading-slash",
	}
	req := &Request{Method: "GET", URL: u}
	rec := httptest.NewRecorder()
	Redirect(rec, req, "", 304)
	if got := rec.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2608
2609
// TestRedirect checks the Location header Redirect produces for a range of
// targets, resolved against a request for http://example.com/qux/.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	cases := []struct {
		in   string
		want string
	}{
		// Absolute URLs pass through unchanged, whatever the scheme.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// Scheme-relative URL.
		{"//foobar.com/baz", "//foobar.com/baz"},
		// Path-absolute URL.
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative paths resolve against the request path /qux/.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// Three leading slashes reduce to one.
		{"///foobar.com/baz", "/foobar.com/baz"},

		// URLs embedded in query strings are preserved verbatim.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},

		// Non-ASCII path bytes are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for _, tc := range cases {
		rec := httptest.NewRecorder()
		Redirect(rec, req, tc.in, 302)
		if code := rec.Code; code != 302 {
			t.Errorf("Redirect(%q) generated status code %v; want %v", tc.in, code, 302)
		}
		if loc := rec.Header().Get("Location"); loc != tc.want {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", tc.in, loc, tc.want)
		}
	}
}
2654
2655
2656
// TestRedirectContentTypeAndBody checks the Content-Type and body Redirect
// writes, per the table below:
//   - a GET with no preset Content-Type gets an HTML link body;
//   - a HEAD gets the HTML Content-Type but no body;
//   - other methods get neither;
//   - a Content-Type entry the caller already set (even an empty or nil one)
//     is left as-is and no body is written.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	cases := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}
	for _, tc := range cases {
		req := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		Redirect(rec, req, "/foo", 302)
		if code := rec.Code; code != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, code, 302)
		}
		if ct := rec.Header().Get("Content-Type"); ct != tc.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, ct, tc.wantCT)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got := string(body); got != tc.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, tc.wantBody)
		}
	}
}
2700
2701
2702
2703
2704
2705
2706
2707 func TestZeroLengthPostAndResponse_h1(t *testing.T) {
2708 testZeroLengthPostAndResponse(t, h1Mode)
2709 }
2710 func TestZeroLengthPostAndResponse_h2(t *testing.T) {
2711 testZeroLengthPostAndResponse(t, h2Mode)
2712 }
2713
2714 func testZeroLengthPostAndResponse(t *testing.T, h2 bool) {
2715 setParallel(t)
2716 defer afterTest(t)
2717 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) {
2718 all, err := io.ReadAll(r.Body)
2719 if err != nil {
2720 t.Fatalf("handler ReadAll: %v", err)
2721 }
2722 if len(all) != 0 {
2723 t.Errorf("handler got %d bytes; expected 0", len(all))
2724 }
2725 rw.Header().Set("Content-Length", "0")
2726 }))
2727 defer cst.close()
2728
2729 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2730 if err != nil {
2731 t.Fatal(err)
2732 }
2733 req.ContentLength = 0
2734
2735 var resp [5]*Response
2736 for i := range resp {
2737 resp[i], err = cst.c.Do(req)
2738 if err != nil {
2739 t.Fatalf("client post #%d: %v", i, err)
2740 }
2741 }
2742
2743 for i := range resp {
2744 all, err := io.ReadAll(resp[i].Body)
2745 if err != nil {
2746 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2747 }
2748 if len(all) != 0 {
2749 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2750 }
2751 }
2752 }
2753
2754 func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) }
2755 func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) }
2756
2757 func TestHandlerPanic_h1(t *testing.T) {
2758 testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing")
2759 }
2760 func TestHandlerPanic_h2(t *testing.T) {
2761 testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing")
2762 }
2763
2764 func TestHandlerPanicWithHijack(t *testing.T) {
2765
2766 testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing")
2767 }
2768
2769 func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) {
2770 defer afterTest(t)
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787 pr, pw := io.Pipe()
2788 log.SetOutput(pw)
2789 defer log.SetOutput(os.Stderr)
2790 defer pw.Close()
2791
2792 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
2793 if withHijack {
2794 rwc, _, err := w.(Hijacker).Hijack()
2795 if err != nil {
2796 t.Logf("unexpected error: %v", err)
2797 }
2798 defer rwc.Close()
2799 }
2800 panic(panicValue)
2801 })
2802 if wrapper != nil {
2803 handler = wrapper(handler)
2804 }
2805 cst := newClientServerTest(t, h2, handler)
2806 defer cst.close()
2807
2808
2809
2810
2811 done := make(chan bool, 1)
2812 go func() {
2813 buf := make([]byte, 4<<10)
2814 _, err := pr.Read(buf)
2815 pr.Close()
2816 if err != nil && err != io.EOF {
2817 t.Error(err)
2818 }
2819 done <- true
2820 }()
2821
2822 _, err := cst.c.Get(cst.ts.URL)
2823 if err == nil {
2824 t.Logf("expected an error")
2825 }
2826
2827 if panicValue == nil {
2828 return
2829 }
2830
2831 select {
2832 case <-done:
2833 return
2834 case <-time.After(5 * time.Second):
2835 t.Fatal("expected server handler to log an error")
2836 }
2837 }
2838
// terrorWriter is an io.Writer that turns every write into a test error on t.
// It serves as a Server.ErrorLog sink, so any server log output fails the
// test.
type terrorWriter struct{ t *testing.T }

// Write reports p as a test error and always claims a complete write.
func (tw terrorWriter) Write(p []byte) (int, error) {
	n := len(p)
	tw.t.Errorf("%s", p)
	return n, nil
}
2845
2846
2847
2848 func TestServerWriteHijackZeroBytes(t *testing.T) {
2849 defer afterTest(t)
2850 done := make(chan struct{})
2851 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
2852 defer close(done)
2853 w.(Flusher).Flush()
2854 conn, _, err := w.(Hijacker).Hijack()
2855 if err != nil {
2856 t.Errorf("Hijack: %v", err)
2857 return
2858 }
2859 defer conn.Close()
2860 _, err = w.Write(nil)
2861 if err != ErrHijacked {
2862 t.Errorf("Write error = %v; want ErrHijacked", err)
2863 }
2864 }))
2865 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
2866 ts.Start()
2867 defer ts.Close()
2868
2869 c := ts.Client()
2870 res, err := c.Get(ts.URL)
2871 if err != nil {
2872 t.Fatal(err)
2873 }
2874 res.Body.Close()
2875 select {
2876 case <-done:
2877 case <-time.After(5 * time.Second):
2878 t.Fatal("timeout")
2879 }
2880 }
2881
2882 func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") }
2883 func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") }
2884 func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") }
2885 func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") }
2886
2887 func testServerNoHeader(t *testing.T, h2 bool, header string) {
2888 setParallel(t)
2889 defer afterTest(t)
2890 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
2891 w.Header()[header] = nil
2892 io.WriteString(w, "<html>foo</html>")
2893 }))
2894 defer cst.close()
2895 res, err := cst.c.Get(cst.ts.URL)
2896 if err != nil {
2897 t.Fatal(err)
2898 }
2899 res.Body.Close()
2900 if got, ok := res.Header[header]; ok {
2901 t.Fatalf("Expected no %s header; got %q", header, got)
2902 }
2903 }
2904
2905 func TestStripPrefix(t *testing.T) {
2906 setParallel(t)
2907 defer afterTest(t)
2908 h := HandlerFunc(func(w ResponseWriter, r *Request) {
2909 w.Header().Set("X-Path", r.URL.Path)
2910 w.Header().Set("X-RawPath", r.URL.RawPath)
2911 })
2912 ts := httptest.NewServer(StripPrefix("/foo/bar", h))
2913 defer ts.Close()
2914
2915 c := ts.Client()
2916
2917 cases := []struct {
2918 reqPath string
2919 path string
2920 rawPath string
2921 }{
2922 {"/foo/bar/qux", "/qux", ""},
2923 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
2924 {"/foo%2Fbar/qux", "", ""},
2925 {"/bar", "", ""},
2926 }
2927 for _, tc := range cases {
2928 t.Run(tc.reqPath, func(t *testing.T) {
2929 res, err := c.Get(ts.URL + tc.reqPath)
2930 if err != nil {
2931 t.Fatal(err)
2932 }
2933 res.Body.Close()
2934 if tc.path == "" {
2935 if res.StatusCode != StatusNotFound {
2936 t.Errorf("got %q, want 404 Not Found", res.Status)
2937 }
2938 return
2939 }
2940 if res.StatusCode != StatusOK {
2941 t.Fatalf("got %q, want 200 OK", res.Status)
2942 }
2943 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
2944 t.Errorf("got Path %q, want %q", g, w)
2945 }
2946 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
2947 t.Errorf("got RawPath %q, want %q", g, w)
2948 }
2949 })
2950 }
2951 }
2952
2953
// TestStripPrefixNotModifyRequest checks that serving through StripPrefix
// leaves the caller's Request URL path unchanged.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	req := httptest.NewRequest("GET", "/foo/bar", nil)
	stripped := StripPrefix("/foo", NotFoundHandler())
	stripped.ServeHTTP(httptest.NewRecorder(), req)
	if got := req.URL.Path; got != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
2962
2963 func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) }
2964 func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) }
2965 func testRequestLimit(t *testing.T, h2 bool) {
2966 setParallel(t)
2967 defer afterTest(t)
2968 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
2969 t.Fatalf("didn't expect to get request in Handler")
2970 }), optQuietLog)
2971 defer cst.close()
2972 req, _ := NewRequest("GET", cst.ts.URL, nil)
2973 var bytesPerHeader = len("header12345: val12345\r\n")
2974 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
2975 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
2976 }
2977 res, err := cst.c.Do(req)
2978 if res != nil {
2979 defer res.Body.Close()
2980 }
2981 if h2 {
2982
2983
2984
2985
2986 if err == nil && res.StatusCode != 431 {
2987 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
2988 }
2989 } else {
2990
2991
2992
2993
2994 if err != nil {
2995 t.Fatalf("Do: %v", err)
2996 }
2997 if res.StatusCode != 431 {
2998 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
2999 }
3000 }
3001 }
3002
// neverEnding is an io.Reader that produces an endless stream of one byte
// value.
type neverEnding byte

// Read fills every byte of p with b and never reports an error or EOF.
func (b neverEnding) Read(p []byte) (n int, err error) {
	fill := byte(b)
	for i := range p {
		p[i] = fill
	}
	n = len(p)
	return n, nil
}
3011
// countReader wraps r. After each Read it atomically adds the number of bytes
// returned to *n, so other goroutines can read the running total with
// atomic.LoadInt64.
type countReader struct {
	r io.Reader
	n *int64
}

// Read delegates to the wrapped reader and records how many bytes it returned.
func (cr countReader) Read(p []byte) (n int, err error) {
	got, rerr := cr.r.Read(p)
	atomic.AddInt64(cr.n, int64(got))
	return got, rerr
}
3022
3023 func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) }
3024 func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) }
3025 func testRequestBodyLimit(t *testing.T, h2 bool) {
3026 setParallel(t)
3027 defer afterTest(t)
3028 const limit = 1 << 20
3029 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
3030 r.Body = MaxBytesReader(w, r.Body, limit)
3031 n, err := io.Copy(io.Discard, r.Body)
3032 if err == nil {
3033 t.Errorf("expected error from io.Copy")
3034 }
3035 if n != limit {
3036 t.Errorf("io.Copy = %d, want %d", n, limit)
3037 }
3038 }))
3039 defer cst.close()
3040
3041 nWritten := new(int64)
3042 req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053 _, _ = cst.c.Do(req)
3054
3055 if atomic.LoadInt64(nWritten) > limit*100 {
3056 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3057 limit, nWritten)
3058 }
3059 }
3060
3061
3062
3063 func TestClientWriteShutdown(t *testing.T) {
3064 if runtime.GOOS == "plan9" {
3065 t.Skip("skipping test; see https://golang.org/issue/17906")
3066 }
3067 defer afterTest(t)
3068 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
3069 defer ts.Close()
3070 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3071 if err != nil {
3072 t.Fatalf("Dial: %v", err)
3073 }
3074 err = conn.(*net.TCPConn).CloseWrite()
3075 if err != nil {
3076 t.Fatalf("CloseWrite: %v", err)
3077 }
3078
3079 bs, err := io.ReadAll(conn)
3080 if err != nil {
3081 t.Errorf("ReadAll: %v", err)
3082 }
3083 got := string(bs)
3084 if got != "" {
3085 t.Errorf("read %q from server; want nothing", got)
3086 }
3087 }
3088
3089
3090
3091 func TestServerBufferedChunking(t *testing.T) {
3092 conn := new(testConn)
3093 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3094 conn.closec = make(chan bool, 1)
3095 ls := &oneConnListener{conn}
3096 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3097 rw.(Flusher).Flush()
3098 rw.Write([]byte{'x'})
3099 rw.Write([]byte{'y'})
3100 rw.Write([]byte{'z'})
3101 }))
3102 <-conn.closec
3103 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3104 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3105 conn.writeBuf.Bytes())
3106 }
3107 }
3108
3109
3110
3111
3112
3113 func TestServerGracefulClose(t *testing.T) {
3114 setParallel(t)
3115 defer afterTest(t)
3116 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3117 Error(w, "bye", StatusUnauthorized)
3118 }))
3119 defer ts.Close()
3120
3121 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3122 if err != nil {
3123 t.Fatal(err)
3124 }
3125 defer conn.Close()
3126 const bodySize = 5 << 20
3127 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3128 for i := 0; i < bodySize; i++ {
3129 req = append(req, 'x')
3130 }
3131 writeErr := make(chan error)
3132 go func() {
3133 _, err := conn.Write(req)
3134 writeErr <- err
3135 }()
3136 br := bufio.NewReader(conn)
3137 lineNum := 0
3138 for {
3139 line, err := br.ReadString('\n')
3140 if err == io.EOF {
3141 break
3142 }
3143 if err != nil {
3144 t.Fatalf("ReadLine: %v", err)
3145 }
3146 lineNum++
3147 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3148 t.Errorf("Response line = %q; want a 401", line)
3149 }
3150 }
3151
3152
3153
3154 <-writeErr
3155 }
3156
3157 func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) }
3158 func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) }
3159 func testCaseSensitiveMethod(t *testing.T, h2 bool) {
3160 defer afterTest(t)
3161 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
3162 if r.Method != "get" {
3163 t.Errorf(`Got method %q; want "get"`, r.Method)
3164 }
3165 }))
3166 defer cst.close()
3167 req, _ := NewRequest("get", cst.ts.URL, nil)
3168 res, err := cst.c.Do(req)
3169 if err != nil {
3170 t.Error(err)
3171 return
3172 }
3173
3174 res.Body.Close()
3175 }
3176
3177
3178
3179
3180
// TestContentLengthZero sends a keep-alive GET, over both HTTP/1.0 and
// HTTP/1.1, to a handler that writes nothing. Each response must have
// ContentLength 0 and no Transfer-Encoding.
func TestContentLengthZero(t *testing.T) {
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {}))
	defer ts.Close()

	for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
		// Each iteration runs in its own function so the deferred Close
		// releases the connection even when t.Fatalf exits early.
		func() {
			conn, err := net.Dial("tcp", ts.Listener.Addr().String())
			if err != nil {
				t.Fatalf("error dialing: %v", err)
			}
			defer conn.Close()
			_, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
			if err != nil {
				t.Fatalf("error writing: %v", err)
			}
			req, _ := NewRequest("GET", "/", nil)
			res, err := ReadResponse(bufio.NewReader(conn), req)
			if err != nil {
				t.Fatalf("error reading response: %v", err)
			}
			defer res.Body.Close()
			if te := res.TransferEncoding; len(te) > 0 {
				t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
			}
			if cl := res.ContentLength; cl != 0 {
				t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
			}
		}()
	}
}
3208
3209 func TestCloseNotifier(t *testing.T) {
3210 defer afterTest(t)
3211 gotReq := make(chan bool, 1)
3212 sawClose := make(chan bool, 1)
3213 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
3214 gotReq <- true
3215 cc := rw.(CloseNotifier).CloseNotify()
3216 <-cc
3217 sawClose <- true
3218 }))
3219 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3220 if err != nil {
3221 t.Fatalf("error dialing: %v", err)
3222 }
3223 diec := make(chan bool)
3224 go func() {
3225 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3226 if err != nil {
3227 t.Error(err)
3228 return
3229 }
3230 <-diec
3231 conn.Close()
3232 }()
3233 For:
3234 for {
3235 select {
3236 case <-gotReq:
3237 diec <- true
3238 case <-sawClose:
3239 break For
3240 case <-time.After(5 * time.Second):
3241 t.Fatal("timeout")
3242 }
3243 }
3244 ts.Close()
3245 }
3246
3247
3248
3249
3250
3251 func TestCloseNotifierPipelined(t *testing.T) {
3252 setParallel(t)
3253 defer afterTest(t)
3254 gotReq := make(chan bool, 2)
3255 sawClose := make(chan bool, 2)
3256 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
3257 gotReq <- true
3258 cc := rw.(CloseNotifier).CloseNotify()
3259 select {
3260 case <-cc:
3261 t.Error("unexpected CloseNotify")
3262 case <-time.After(100 * time.Millisecond):
3263 }
3264 sawClose <- true
3265 }))
3266 defer ts.Close()
3267 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3268 if err != nil {
3269 t.Fatalf("error dialing: %v", err)
3270 }
3271 diec := make(chan bool, 1)
3272 defer close(diec)
3273 go func() {
3274 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3275 _, err = io.WriteString(conn, req+req)
3276 if err != nil {
3277 t.Error(err)
3278 return
3279 }
3280 <-diec
3281 conn.Close()
3282 }()
3283 reqs := 0
3284 closes := 0
3285 for {
3286 select {
3287 case <-gotReq:
3288 reqs++
3289 if reqs > 2 {
3290 t.Fatal("too many requests")
3291 }
3292 case <-sawClose:
3293 closes++
3294 if closes > 1 {
3295 return
3296 }
3297 case <-time.After(5 * time.Second):
3298 ts.CloseClientConnections()
3299 t.Fatal("timeout")
3300 }
3301 }
3302 }
3303
3304 func TestCloseNotifierChanLeak(t *testing.T) {
3305 defer afterTest(t)
3306 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3307 for i := 0; i < 20; i++ {
3308 var output bytes.Buffer
3309 conn := &rwTestConn{
3310 Reader: bytes.NewReader(req),
3311 Writer: &output,
3312 closec: make(chan bool, 1),
3313 }
3314 ln := &oneConnListener{conn: conn}
3315 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3316
3317
3318
3319 _ = rw.(CloseNotifier).CloseNotify()
3320 })
3321 go Serve(ln, handler)
3322 <-conn.closec
3323 }
3324 }
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335 func TestHijackAfterCloseNotifier(t *testing.T) {
3336 defer afterTest(t)
3337 script := make(chan string, 2)
3338 script <- "closenotify"
3339 script <- "hijack"
3340 close(script)
3341 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3342 plan := <-script
3343 switch plan {
3344 default:
3345 panic("bogus plan; too many requests")
3346 case "closenotify":
3347 w.(CloseNotifier).CloseNotify()
3348 w.Header().Set("X-Addr", r.RemoteAddr)
3349 case "hijack":
3350 c, _, err := w.(Hijacker).Hijack()
3351 if err != nil {
3352 t.Errorf("Hijack in Handler: %v", err)
3353 return
3354 }
3355 if _, ok := c.(*net.TCPConn); !ok {
3356
3357
3358 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3359 }
3360 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3361 c.Close()
3362 return
3363 }
3364 }))
3365 defer ts.Close()
3366 res1, err := Get(ts.URL)
3367 if err != nil {
3368 log.Fatal(err)
3369 }
3370 res2, err := Get(ts.URL)
3371 if err != nil {
3372 log.Fatal(err)
3373 }
3374 addr1 := res1.Header.Get("X-Addr")
3375 addr2 := res2.Header.Get("X-Addr")
3376 if addr1 == "" || addr1 != addr2 {
3377 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3378 }
3379 }
3380
3381 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3382 setParallel(t)
3383 defer afterTest(t)
3384 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3385 bodyOkay := make(chan bool, 1)
3386 gotCloseNotify := make(chan bool, 1)
3387 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3388 defer close(bodyOkay)
3389
3390 reqBody := r.Body
3391 r.Body = nil
3392
3393 gone := w.(CloseNotifier).CloseNotify()
3394 slurp, err := io.ReadAll(reqBody)
3395 if err != nil {
3396 t.Errorf("Body read: %v", err)
3397 return
3398 }
3399 if len(slurp) != len(requestBody) {
3400 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3401 return
3402 }
3403 if !bytes.Equal(slurp, requestBody) {
3404 t.Error("Backend read wrong request body.")
3405 return
3406 }
3407 bodyOkay <- true
3408 select {
3409 case <-gone:
3410 gotCloseNotify <- true
3411 case <-time.After(5 * time.Second):
3412 gotCloseNotify <- false
3413 }
3414 }))
3415 defer ts.Close()
3416
3417 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3418 if err != nil {
3419 t.Fatal(err)
3420 }
3421 defer conn.Close()
3422
3423 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3424 len(requestBody), requestBody)
3425 if !<-bodyOkay {
3426
3427 return
3428 }
3429 conn.Close()
3430 if !<-gotCloseNotify {
3431 t.Error("timeout waiting for CloseNotify")
3432 }
3433 }
3434
// TestOptions sends two requests on one raw connection: "OPTIONS *" must get
// a 200 and "GET *" must get a 400. Neither may reach the ServeMux handler:
// the first RequestURI it records must be that of a later "/second" request.
func TestOptions(t *testing.T) {
	seenURIs := make(chan string, 2)
	mux := NewServeMux()
	mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
		seenURIs <- r.RequestURI
	})
	ts := httptest.NewServer(mux)
	defer ts.Close()

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	br := bufio.NewReader(conn)

	// OPTIONS with the "*" target.
	if _, err := conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
	}

	// Any other method with the "*" target.
	if _, err := conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	if res, err = ReadResponse(br, &Request{Method: "GET"}); err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 400 {
		t.Errorf("Got non-400 response to GET *: %#v", res)
	}

	if res, err = Get(ts.URL + "/second"); err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if got := <-seenURIs; got != "/second" {
		t.Errorf("Handler saw request for %q; want /second", got)
	}
}
3486
3487
3488
3489
3490
3491
3492
3493
3494
3495
3496 func TestHeaderToWire(t *testing.T) {
3497 tests := []struct {
3498 name string
3499 handler func(ResponseWriter, *Request)
3500 check func(got, logs string) error
3501 }{
3502 {
3503 name: "write without Header",
3504 handler: func(rw ResponseWriter, r *Request) {
3505 rw.Write([]byte("hello world"))
3506 },
3507 check: func(got, logs string) error {
3508 if !strings.Contains(got, "Content-Length:") {
3509 return errors.New("no content-length")
3510 }
3511 if !strings.Contains(got, "Content-Type: text/plain") {
3512 return errors.New("no content-type")
3513 }
3514 return nil
3515 },
3516 },
3517 {
3518 name: "Header mutation before write",
3519 handler: func(rw ResponseWriter, r *Request) {
3520 h := rw.Header()
3521 h.Set("Content-Type", "some/type")
3522 rw.Write([]byte("hello world"))
3523 h.Set("Too-Late", "bogus")
3524 },
3525 check: func(got, logs string) error {
3526 if !strings.Contains(got, "Content-Length:") {
3527 return errors.New("no content-length")
3528 }
3529 if !strings.Contains(got, "Content-Type: some/type") {
3530 return errors.New("wrong content-type")
3531 }
3532 if strings.Contains(got, "Too-Late") {
3533 return errors.New("don't want too-late header")
3534 }
3535 return nil
3536 },
3537 },
3538 {
3539 name: "write then useless Header mutation",
3540 handler: func(rw ResponseWriter, r *Request) {
3541 rw.Write([]byte("hello world"))
3542 rw.Header().Set("Too-Late", "Write already wrote headers")
3543 },
3544 check: func(got, logs string) error {
3545 if strings.Contains(got, "Too-Late") {
3546 return errors.New("header appeared from after WriteHeader")
3547 }
3548 return nil
3549 },
3550 },
3551 {
3552 name: "flush then write",
3553 handler: func(rw ResponseWriter, r *Request) {
3554 rw.(Flusher).Flush()
3555 rw.Write([]byte("post-flush"))
3556 rw.Header().Set("Too-Late", "Write already wrote headers")
3557 },
3558 check: func(got, logs string) error {
3559 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3560 return errors.New("not chunked")
3561 }
3562 if strings.Contains(got, "Too-Late") {
3563 return errors.New("header appeared from after WriteHeader")
3564 }
3565 return nil
3566 },
3567 },
3568 {
3569 name: "header then flush",
3570 handler: func(rw ResponseWriter, r *Request) {
3571 rw.Header().Set("Content-Type", "some/type")
3572 rw.(Flusher).Flush()
3573 rw.Write([]byte("post-flush"))
3574 rw.Header().Set("Too-Late", "Write already wrote headers")
3575 },
3576 check: func(got, logs string) error {
3577 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3578 return errors.New("not chunked")
3579 }
3580 if strings.Contains(got, "Too-Late") {
3581 return errors.New("header appeared from after WriteHeader")
3582 }
3583 if !strings.Contains(got, "Content-Type: some/type") {
3584 return errors.New("wrong content-type")
3585 }
3586 return nil
3587 },
3588 },
3589 {
3590 name: "sniff-on-first-write content-type",
3591 handler: func(rw ResponseWriter, r *Request) {
3592 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3593 rw.Header().Set("Content-Type", "x/wrong")
3594 },
3595 check: func(got, logs string) error {
3596 if !strings.Contains(got, "Content-Type: text/html") {
3597 return errors.New("wrong content-type; want html")
3598 }
3599 return nil
3600 },
3601 },
3602 {
3603 name: "explicit content-type wins",
3604 handler: func(rw ResponseWriter, r *Request) {
3605 rw.Header().Set("Content-Type", "some/type")
3606 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3607 },
3608 check: func(got, logs string) error {
3609 if !strings.Contains(got, "Content-Type: some/type") {
3610 return errors.New("wrong content-type; want html")
3611 }
3612 return nil
3613 },
3614 },
3615 {
3616 name: "empty handler",
3617 handler: func(rw ResponseWriter, r *Request) {
3618 },
3619 check: func(got, logs string) error {
3620 if !strings.Contains(got, "Content-Length: 0") {
3621 return errors.New("want 0 content-length")
3622 }
3623 return nil
3624 },
3625 },
3626 {
3627 name: "only Header, no write",
3628 handler: func(rw ResponseWriter, r *Request) {
3629 rw.Header().Set("Some-Header", "some-value")
3630 },
3631 check: func(got, logs string) error {
3632 if !strings.Contains(got, "Some-Header") {
3633 return errors.New("didn't get header")
3634 }
3635 return nil
3636 },
3637 },
3638 {
3639 name: "WriteHeader call",
3640 handler: func(rw ResponseWriter, r *Request) {
3641 rw.WriteHeader(404)
3642 rw.Header().Set("Too-Late", "some-value")
3643 },
3644 check: func(got, logs string) error {
3645 if !strings.Contains(got, "404") {
3646 return errors.New("wrong status")
3647 }
3648 if strings.Contains(got, "Too-Late") {
3649 return errors.New("shouldn't have seen Too-Late")
3650 }
3651 return nil
3652 },
3653 },
3654 }
3655 for _, tc := range tests {
3656 ht := newHandlerTest(HandlerFunc(tc.handler))
3657 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3658 logs := ht.logbuf.String()
3659 if err := tc.check(got, logs); err != nil {
3660 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3661 }
3662 }
3663 }
3664
3665 type errorListener struct {
3666 errs []error
3667 }
3668
3669 func (l *errorListener) Accept() (c net.Conn, err error) {
3670 if len(l.errs) == 0 {
3671 return nil, io.EOF
3672 }
3673 err = l.errs[0]
3674 l.errs = l.errs[1:]
3675 return
3676 }
3677
3678 func (l *errorListener) Close() error {
3679 return nil
3680 }
3681
3682 func (l *errorListener) Addr() net.Addr {
3683 return dummyAddr("test-address")
3684 }
3685
3686 func TestAcceptMaxFds(t *testing.T) {
3687 setParallel(t)
3688
3689 ln := &errorListener{[]error{
3690 &net.OpError{
3691 Op: "accept",
3692 Err: syscall.EMFILE,
3693 }}}
3694 server := &Server{
3695 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3696 ErrorLog: log.New(io.Discard, "", 0),
3697 }
3698 err := server.Serve(ln)
3699 if err != io.EOF {
3700 t.Errorf("got error %v, want EOF", err)
3701 }
3702 }
3703
3704 func TestWriteAfterHijack(t *testing.T) {
3705 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3706 var buf bytes.Buffer
3707 wrotec := make(chan bool, 1)
3708 conn := &rwTestConn{
3709 Reader: bytes.NewReader(req),
3710 Writer: &buf,
3711 closec: make(chan bool, 1),
3712 }
3713 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3714 conn, bufrw, err := rw.(Hijacker).Hijack()
3715 if err != nil {
3716 t.Error(err)
3717 return
3718 }
3719 go func() {
3720 bufrw.Write([]byte("[hijack-to-bufw]"))
3721 bufrw.Flush()
3722 conn.Write([]byte("[hijack-to-conn]"))
3723 conn.Close()
3724 wrotec <- true
3725 }()
3726 })
3727 ln := &oneConnListener{conn: conn}
3728 go Serve(ln, handler)
3729 <-conn.closec
3730 <-wrotec
3731 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
3732 t.Errorf("wrote %q; want %q", g, w)
3733 }
3734 }
3735
3736 func TestDoubleHijack(t *testing.T) {
3737 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3738 var buf bytes.Buffer
3739 conn := &rwTestConn{
3740 Reader: bytes.NewReader(req),
3741 Writer: &buf,
3742 closec: make(chan bool, 1),
3743 }
3744 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3745 conn, _, err := rw.(Hijacker).Hijack()
3746 if err != nil {
3747 t.Error(err)
3748 return
3749 }
3750 _, _, err = rw.(Hijacker).Hijack()
3751 if err == nil {
3752 t.Errorf("got err = nil; want err != nil")
3753 }
3754 conn.Close()
3755 })
3756 ln := &oneConnListener{conn: conn}
3757 go Serve(ln, handler)
3758 <-conn.closec
3759 }
3760
3761
3762
3763
3764
3765
3766
3767 func TestHTTP10ConnectionHeader(t *testing.T) {
3768 defer afterTest(t)
3769
3770 mux := NewServeMux()
3771 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
3772 ts := httptest.NewServer(mux)
3773 defer ts.Close()
3774
3775
3776 tests := []struct {
3777 req string
3778 expect []string
3779 }{
3780 {
3781 req: "GET / HTTP/1.0\r\n\r\n",
3782 expect: nil,
3783 },
3784 {
3785 req: "OPTIONS * HTTP/1.0\r\n\r\n",
3786 expect: nil,
3787 },
3788 {
3789 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
3790 expect: []string{"keep-alive"},
3791 },
3792 }
3793
3794 for _, tt := range tests {
3795 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3796 if err != nil {
3797 t.Fatal("dial err:", err)
3798 }
3799
3800 _, err = fmt.Fprint(conn, tt.req)
3801 if err != nil {
3802 t.Fatal("conn write err:", err)
3803 }
3804
3805 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
3806 if err != nil {
3807 t.Fatal("ReadResponse err:", err)
3808 }
3809 conn.Close()
3810 resp.Body.Close()
3811
3812 got := resp.Header["Connection"]
3813 if !reflect.DeepEqual(got, tt.expect) {
3814 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
3815 }
3816 }
3817 }
3818
3819
3820 func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) }
3821 func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) }
3822 func testServerReaderFromOrder(t *testing.T, h2 bool) {
3823 setParallel(t)
3824 defer afterTest(t)
3825 pr, pw := io.Pipe()
3826 const size = 3 << 20
3827 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
3828 rw.Header().Set("Content-Type", "text/plain")
3829 done := make(chan bool)
3830 go func() {
3831 io.Copy(rw, pr)
3832 close(done)
3833 }()
3834 time.Sleep(25 * time.Millisecond)
3835 n, err := io.Copy(io.Discard, req.Body)
3836 if err != nil {
3837 t.Errorf("handler Copy: %v", err)
3838 return
3839 }
3840 if n != size {
3841 t.Errorf("handler Copy = %d; want %d", n, size)
3842 }
3843 pw.Write([]byte("hi"))
3844 pw.Close()
3845 <-done
3846 }))
3847 defer cst.close()
3848
3849 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
3850 if err != nil {
3851 t.Fatal(err)
3852 }
3853 res, err := cst.c.Do(req)
3854 if err != nil {
3855 t.Fatal(err)
3856 }
3857 all, err := io.ReadAll(res.Body)
3858 if err != nil {
3859 t.Fatal(err)
3860 }
3861 res.Body.Close()
3862 if string(all) != "hi" {
3863 t.Errorf("Body = %q; want hi", all)
3864 }
3865 }
3866
3867
3868 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
3869 for _, code := range []int{StatusNotModified, StatusNoContent, StatusContinue} {
3870 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
3871 if r.URL.Path == "/header" {
3872 w.Header().Set("Content-Length", "123")
3873 }
3874 w.WriteHeader(code)
3875 if r.URL.Path == "/more" {
3876 w.Write([]byte("stuff"))
3877 }
3878 }))
3879 for _, req := range []string{
3880 "GET / HTTP/1.0",
3881 "GET /header HTTP/1.0",
3882 "GET /more HTTP/1.0",
3883 "GET / HTTP/1.1\nHost: foo",
3884 "GET /header HTTP/1.1\nHost: foo",
3885 "GET /more HTTP/1.1\nHost: foo",
3886 } {
3887 got := ht.rawResponse(req)
3888 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
3889 if !strings.Contains(got, wantStatus) {
3890 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
3891 } else if strings.Contains(got, "Content-Length") {
3892 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
3893 } else if strings.Contains(got, "stuff") {
3894 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
3895 }
3896 }
3897 }
3898 }
3899
3900 func TestContentTypeOkayOn204(t *testing.T) {
3901 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
3902 w.Header().Set("Content-Length", "123")
3903 w.Header().Set("Content-Type", "foo/bar")
3904 w.WriteHeader(204)
3905 }))
3906 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
3907 if !strings.Contains(got, "Content-Type: foo/bar") {
3908 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
3909 }
3910 if strings.Contains(got, "Content-Length: 123") {
3911 t.Errorf("Response = %q; don't want a Content-Length", got)
3912 }
3913 }
3914
3915
3916
3917
3918
3919
3920
func TestTransportAndServerSharedBodyRace_h1(t *testing.T) {
	testTransportAndServerSharedBodyRace(t, h1Mode)
}
func TestTransportAndServerSharedBodyRace_h2(t *testing.T) {
	testTransportAndServerSharedBodyRace(t, h2Mode)
}

// testTransportAndServerSharedBodyRace sets up a proxy server that forwards
// its incoming request body (req.Body) as the body of an outbound request
// (req2) to a backend. Partway through the backend's response, the proxy
// cancels req2. The Transport (for req2) and the proxy's Server (for req.Body)
// may then both act on the same body; the test exercises that shared-body path.
func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
	setParallel(t)
	defer afterTest(t)

	const bodySize = 1 << 20

	// errorf reports through t.Error and also prints the message with the
	// builtin println (stderr). The copy stays visible even if the process is
	// later killed by quitTimer's log.Fatalf before test output is flushed.
	errorf := func(format string, args ...any) {
		v := fmt.Sprintf(format, args...)
		println(v)
		t.Error(v)
	}

	// unblockBackend keeps every backend handler alive until the test
	// returns (it is closed by the last-registered defer below).
	unblockBackend := make(chan bool)
	backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
		gone := rw.(CloseNotifier).CloseNotify()
		// didCopy receives the (n, err) result of echoing the request body back.
		didCopy := make(chan any)
		go func() {
			n, err := io.CopyN(rw, req.Body, bodySize)
			didCopy <- []any{n, err}
		}()
		isGone := false
		// Wait for the echo to finish, printing a progress line each idle
		// second that notes whether the client (the proxy) has gone away.
	Loop:
		for {
			select {
			case <-didCopy:
				break Loop
			case <-gone:
				isGone = true
			case <-time.After(time.Second):
				println("1 second passes in backend, proxygone=", isGone)
			}
		}
		<-unblockBackend
	}))
	// Deferred calls run LIFO. quitTimer is assigned by a later-registered
	// defer (so it runs earlier) and is therefore non-nil by the time this
	// Stop runs, after both servers have been closed.
	var quitTimer *time.Timer
	defer func() { quitTimer.Stop() }()
	defer backend.close()

	// backendRespc hands the partially read backend response to the test
	// body so it can be closed during cleanup.
	backendRespc := make(chan *Response, 1)
	var proxy *clientServerTest
	proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// Forward the incoming body as-is to the backend.
		req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
		req2.ContentLength = bodySize
		cancel := make(chan struct{})
		req2.Cancel = cancel

		bresp, err := proxy.c.Do(req2)
		if err != nil {
			errorf("Proxy outbound request: %v", err)
			return
		}
		// Read only half of the echoed body before cancelling.
		_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
		if err != nil {
			errorf("Proxy copy error: %v", err)
			return
		}
		backendRespc <- bresp

		// Cancel the outbound request while req.Body (which is also
		// req2.Body) may still be in use by both the Transport and this
		// handler's Server. HTTP/2 uses the Cancel channel; HTTP/1 uses
		// Transport.CancelRequest.
		if h2 {
			close(cancel)
		} else {
			proxy.c.Transport.(*Transport).CancelRequest(req2)
		}
		rw.Write([]byte("OK"))
	}))
	defer proxy.close()
	defer func() {
		// Registered after proxy.close, so this runs before both server
		// closes. It arms a watchdog that dumps all goroutine stacks and
		// exits the process if shutdown has not finished (and quitTimer.Stop
		// has not run) within 7 seconds.
		quitTimer = time.AfterFunc(7*time.Second, func() {
			debug.SetTraceback("ALL")
			stacks := make([]byte, 1<<20)
			stacks = stacks[:runtime.Stack(stacks, true)]
			fmt.Fprintf(os.Stderr, "%s", stacks)
			log.Fatalf("Timeout.")
		})
	}()

	defer close(unblockBackend)
	req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
	res, err := proxy.c.Do(req)
	if err != nil {
		t.Fatalf("Original request: %v", err)
	}

	// Close both response bodies so their goroutines can finish.
	res.Body.Close()
	select {
	case res := <-backendRespc:
		res.Body.Close()
	default:
		// The proxy handler returned before sending bresp (it already
		// reported an error via errorf), so there is nothing to close.
	}
}
4030
4031
4032
4033
4034 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4035 if testing.Short() {
4036 t.Skip("skipping in -short mode")
4037 }
4038 defer afterTest(t)
4039
4040 readErrCh := make(chan error, 1)
4041 errCh := make(chan error, 2)
4042
4043 server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
4044 go func(body io.Reader) {
4045 _, err := body.Read(make([]byte, 100))
4046 readErrCh <- err
4047 }(req.Body)
4048 time.Sleep(500 * time.Millisecond)
4049 }))
4050 defer server.Close()
4051
4052 closeConn := make(chan bool)
4053 defer close(closeConn)
4054 go func() {
4055 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4056 if err != nil {
4057 errCh <- err
4058 return
4059 }
4060 defer conn.Close()
4061 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4062 if err != nil {
4063 errCh <- err
4064 return
4065 }
4066
4067
4068 <-closeConn
4069 }()
4070 select {
4071 case err := <-readErrCh:
4072 if err == nil {
4073 t.Error("Read was nil. Expected error.")
4074 }
4075 case err := <-errCh:
4076 t.Error(err)
4077 case <-time.After(5 * time.Second):
4078 t.Error("timeout")
4079 }
4080 }
4081
4082
4083 func TestResponseWriterWriteString(t *testing.T) {
4084 okc := make(chan bool, 1)
4085 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4086 _, ok := w.(io.StringWriter)
4087 okc <- ok
4088 }))
4089 ht.rawResponse("GET / HTTP/1.0")
4090 select {
4091 case ok := <-okc:
4092 if !ok {
4093 t.Error("ResponseWriter did not implement io.StringWriter")
4094 }
4095 default:
4096 t.Error("handler was never called")
4097 }
4098 }
4099
4100 func TestAppendTime(t *testing.T) {
4101 var b [len(TimeFormat)]byte
4102 t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
4103 res := ExportAppendTime(b[:0], t1)
4104 t2, err := ParseTime(string(res))
4105 if err != nil {
4106 t.Fatalf("Error parsing time: %s", err)
4107 }
4108 if !t1.Equal(t2) {
4109 t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
4110 }
4111 }
4112
// TestServerConnState checks the sequence of ConnState values the server
// reports for several connection lifecycles:
//   - keep-alive followed by a server- or client-requested close
//   - hijacking, with and without a handler panic
//   - a connection closed without sending anything
//   - a malformed request
//   - a client-side close after one complete request
func TestServerConnState(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	// handler holds the per-path behaviors; the server below dispatches on
	// r.URL.Path.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		// Asks for the connection to be closed after this response.
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		// Takes over the connection and writes a complete HTTP/1.0 response by hand.
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		// Same as /hijack, then panics after the hijacked conn is closed.
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog records the states observed for a single tracked connection.
	type stateLog struct {
		// active is the tracked connection, set from the first StateNew seen.
		active net.Conn
		// got holds the states observed so far, in order.
		got []ConnState
		// want is the expected state sequence.
		want []ConnState
		// complete, if non-nil, is closed (and then cleared) once got has
		// reached len(want) or is no longer a prefix of want.
		complete chan<- struct{}
	}
	// activeLog is a one-slot mailbox holding the current *stateLog.
	// Receiving from it takes ownership of the log until it is sent back,
	// which serializes access between wantLog and the ConnState callback.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh stateLog expecting want and runs doRequests.
	// It then waits for the log's complete channel, up to 5s or until one
	// second before the test deadline when one is set. Finally it compares
	// the observed states with want.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		stateDelay := 5 * time.Second
		if deadline, ok := t.Deadline(); ok {
			// With a test deadline, wait until shortly before it instead
			// of the fixed 5s, leaving a margin for cleanup.
			const arbitraryCleanupMargin = 1 * time.Second
			stateDelay = time.Until(deadline) - arbitraryCleanupMargin
		}
		timer := time.NewTimer(stateDelay)
		select {
		case <-timer.C:
			t.Errorf("Timed out after %v waiting for connection to change state.", stateDelay)
		case <-complete:
			timer.Stop()
		}
		sl := <-activeLog
		if !reflect.DeepEqual(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// sl is deliberately not sent back. Until the next wantLog call (or
		// the deferred cleanup) refills activeLog, any further ConnState
		// callback blocks on its receive.
	}

	ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}))
	defer func() {
		// Refill activeLog so ConnState callbacks during shutdown are not
		// left blocked on an empty mailbox.
		activeLog <- &stateLog{}
		ts.Close()
	}()

	ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	ts.Config.ConnState = func(c net.Conn, state ConnState) {
		if c == nil {
			t.Errorf("nil conn seen in state %s", state)
			return
		}
		sl := <-activeLog
		// The first StateNew after a log is installed picks the connection
		// to track; states for any other connection are reported as errors.
		if sl.active == nil && state == StateNew {
			sl.active = c
		} else if sl.active != c {
			t.Errorf("unexpected conn in state %s", state)
			activeLog <- sl
			return
		}
		sl.got = append(sl.got, state)
		// Wake wantLog once the sequence is long enough or has diverged.
		if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
			close(sl.complete)
			sl.complete = nil
		}
		activeLog <- sl
	}

	ts.Start()
	c := ts.Client()

	// mustGet GETs url and reads the body to completion. headers is a flat
	// list of key, value pairs to add to the request.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Keep-alive reuse, then the server closes via "Connection: close" in the response.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Keep-alive reuse, then the client requests "Connection: close".
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// A hijacked connection ends in StateHijacked.
	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// A panic after hijacking reports no states beyond StateHijacked.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection closed before sending any bytes.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request line.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// One complete request, then the client closes the idle connection.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4292
4293 func TestServerKeepAlivesEnabled(t *testing.T) {
4294 defer afterTest(t)
4295 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
4296 ts.Config.SetKeepAlivesEnabled(false)
4297 ts.Start()
4298 defer ts.Close()
4299 res, err := Get(ts.URL)
4300 if err != nil {
4301 t.Fatal(err)
4302 }
4303 defer res.Body.Close()
4304 if !res.Close {
4305 t.Errorf("Body.Close == false; want true")
4306 }
4307 }
4308
4309
4310 func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) }
4311 func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) }
4312 func testServerEmptyBodyRace(t *testing.T, h2 bool) {
4313 setParallel(t)
4314 defer afterTest(t)
4315 var n int32
4316 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
4317 atomic.AddInt32(&n, 1)
4318 }), optQuietLog)
4319 defer cst.close()
4320 var wg sync.WaitGroup
4321 const reqs = 20
4322 for i := 0; i < reqs; i++ {
4323 wg.Add(1)
4324 go func() {
4325 defer wg.Done()
4326 res, err := cst.c.Get(cst.ts.URL)
4327 if err != nil {
4328
4329
4330 time.Sleep(10 * time.Millisecond)
4331 res, err = cst.c.Get(cst.ts.URL)
4332 if err != nil {
4333 t.Error(err)
4334 return
4335 }
4336 }
4337 defer res.Body.Close()
4338 _, err = io.Copy(io.Discard, res.Body)
4339 if err != nil {
4340 t.Error(err)
4341 return
4342 }
4343 }()
4344 }
4345 wg.Wait()
4346 if got := atomic.LoadInt32(&n); got != reqs {
4347 t.Errorf("handler ran %d times; want %d", got, reqs)
4348 }
4349 }
4350
4351 func TestServerConnStateNew(t *testing.T) {
4352 sawNew := false
4353 srv := &Server{
4354 ConnState: func(c net.Conn, state ConnState) {
4355 if state == StateNew {
4356 sawNew = true
4357 }
4358 },
4359 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4360 }
4361 srv.Serve(&oneConnListener{
4362 conn: &rwTestConn{
4363 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4364 Writer: io.Discard,
4365 },
4366 })
4367 if !sawNew {
4368 t.Error("StateNew not seen")
4369 }
4370 }
4371
4372 type closeWriteTestConn struct {
4373 rwTestConn
4374 didCloseWrite bool
4375 }
4376
4377 func (c *closeWriteTestConn) CloseWrite() error {
4378 c.didCloseWrite = true
4379 return nil
4380 }
4381
4382 func TestCloseWrite(t *testing.T) {
4383 setParallel(t)
4384 var srv Server
4385 var testConn closeWriteTestConn
4386 c := ExportServerNewConn(&srv, &testConn)
4387 ExportCloseWriteAndWait(c)
4388 if !testConn.didCloseWrite {
4389 t.Error("didn't see CloseWrite call")
4390 }
4391 }
4392
4393
4394
4395
4396
4397
4398
4399
4400 func TestServerFlushAndHijack(t *testing.T) {
4401 defer afterTest(t)
4402 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4403 io.WriteString(w, "Hello, ")
4404 w.(Flusher).Flush()
4405 conn, buf, _ := w.(Hijacker).Hijack()
4406 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4407 if err := buf.Flush(); err != nil {
4408 t.Error(err)
4409 }
4410 if err := conn.Close(); err != nil {
4411 t.Error(err)
4412 }
4413 }))
4414 defer ts.Close()
4415 res, err := Get(ts.URL)
4416 if err != nil {
4417 t.Fatal(err)
4418 }
4419 defer res.Body.Close()
4420 all, err := io.ReadAll(res.Body)
4421 if err != nil {
4422 t.Fatal(err)
4423 }
4424 if want := "Hello, world!"; string(all) != want {
4425 t.Errorf("Got %q; want %q", all, want)
4426 }
4427 }
4428
4429
4430
4431
4432
4433
4434
4435 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4436 if testing.Short() {
4437 t.Skip("skipping in -short mode")
4438 }
4439 defer afterTest(t)
4440 const numReq = 3
4441 addrc := make(chan string, numReq)
4442 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4443 addrc <- r.RemoteAddr
4444 time.Sleep(500 * time.Millisecond)
4445 w.(Flusher).Flush()
4446 }))
4447 ts.Config.WriteTimeout = 250 * time.Millisecond
4448 ts.Start()
4449 defer ts.Close()
4450
4451 errc := make(chan error, numReq)
4452 go func() {
4453 defer close(errc)
4454 for i := 0; i < numReq; i++ {
4455 res, err := Get(ts.URL)
4456 if res != nil {
4457 res.Body.Close()
4458 }
4459 errc <- err
4460 }
4461 }()
4462
4463 timeout := time.NewTimer(numReq * 2 * time.Second)
4464 defer timeout.Stop()
4465 addrSeen := map[string]bool{}
4466 numOkay := 0
4467 for {
4468 select {
4469 case v := <-addrc:
4470 addrSeen[v] = true
4471 case err, ok := <-errc:
4472 if !ok {
4473 if len(addrSeen) != numReq {
4474 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4475 }
4476 if numOkay != 0 {
4477 t.Errorf("got %d successful client requests; want 0", numOkay)
4478 }
4479 return
4480 }
4481 if err == nil {
4482 numOkay++
4483 }
4484 case <-timeout.C:
4485 t.Fatal("timeout waiting for requests to complete")
4486 }
4487 }
4488 }
4489
4490
4491
4492 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4493 defer afterTest(t)
4494 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4495 w.Header().Set("Transfer-Encoding", "foo")
4496 io.WriteString(w, "<html>")
4497 }))
4498 defer ts.Close()
4499 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4500 if err != nil {
4501 t.Fatalf("Dial: %v", err)
4502 }
4503 defer c.Close()
4504 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4505 t.Fatal(err)
4506 }
4507 bs := bufio.NewScanner(c)
4508 var got bytes.Buffer
4509 for bs.Scan() {
4510 if strings.TrimSpace(bs.Text()) == "" {
4511 break
4512 }
4513 got.WriteString(bs.Text())
4514 got.WriteByte('\n')
4515 }
4516 if err := bs.Err(); err != nil {
4517 t.Fatal(err)
4518 }
4519 if strings.Contains(got.String(), "Content-Length") {
4520 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4521 }
4522 if strings.Contains(got.String(), "Content-Type") {
4523 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4524 }
4525 }
4526
4527
4528
4529 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4530 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4531 "\r\n\r\n" +
4532 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4533 var buf bytes.Buffer
4534 conn := &rwTestConn{
4535 Reader: bytes.NewReader(req),
4536 Writer: &buf,
4537 closec: make(chan bool, 1),
4538 }
4539 ln := &oneConnListener{conn: conn}
4540 numReq := 0
4541 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4542 numReq++
4543 }))
4544 <-conn.closec
4545 if numReq != 2 {
4546 t.Errorf("num requests = %d; want 2", numReq)
4547 t.Logf("Res: %s", buf.Bytes())
4548 }
4549 }
4550
4551 func TestIssue13893_Expect100(t *testing.T) {
4552
4553 req := reqBytes(`PUT /readbody HTTP/1.1
4554 User-Agent: PycURL/7.22.0
4555 Host: 127.0.0.1:9000
4556 Accept: */*
4557 Expect: 100-continue
4558 Content-Length: 10
4559
4560 HelloWorld
4561
4562 `)
4563 var buf bytes.Buffer
4564 conn := &rwTestConn{
4565 Reader: bytes.NewReader(req),
4566 Writer: &buf,
4567 closec: make(chan bool, 1),
4568 }
4569 ln := &oneConnListener{conn: conn}
4570 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4571 if _, ok := r.Header["Expect"]; !ok {
4572 t.Error("Expect header should not be filtered out")
4573 }
4574 }))
4575 <-conn.closec
4576 }
4577
4578 func TestIssue11549_Expect100(t *testing.T) {
4579 req := reqBytes(`PUT /readbody HTTP/1.1
4580 User-Agent: PycURL/7.22.0
4581 Host: 127.0.0.1:9000
4582 Accept: */*
4583 Expect: 100-continue
4584 Content-Length: 10
4585
4586 HelloWorldPUT /noreadbody HTTP/1.1
4587 User-Agent: PycURL/7.22.0
4588 Host: 127.0.0.1:9000
4589 Accept: */*
4590 Expect: 100-continue
4591 Content-Length: 10
4592
4593 GET /should-be-ignored HTTP/1.1
4594 Host: foo
4595
4596 `)
4597 var buf bytes.Buffer
4598 conn := &rwTestConn{
4599 Reader: bytes.NewReader(req),
4600 Writer: &buf,
4601 closec: make(chan bool, 1),
4602 }
4603 ln := &oneConnListener{conn: conn}
4604 numReq := 0
4605 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4606 numReq++
4607 if r.URL.Path == "/readbody" {
4608 io.ReadAll(r.Body)
4609 }
4610 io.WriteString(w, "Hello world!")
4611 }))
4612 <-conn.closec
4613 if numReq != 2 {
4614 t.Errorf("num requests = %d; want 2", numReq)
4615 }
4616 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4617 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4618 }
4619 }
4620
4621
4622
4623 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4624 setParallel(t)
4625 conn := &testConn{closec: make(chan bool)}
4626 conn.readBuf.Write([]byte(fmt.Sprintf(
4627 "POST / HTTP/1.1\r\n" +
4628 "Host: test\r\n" +
4629 "Content-Length: 9999999999\r\n" +
4630 "\r\n" + strings.Repeat("a", 1<<20))))
4631
4632 ls := &oneConnListener{conn}
4633 var inHandlerLen int
4634 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4635 inHandlerLen = conn.readBuf.Len()
4636 rw.WriteHeader(404)
4637 }))
4638 <-conn.closec
4639 afterHandlerLen := conn.readBuf.Len()
4640
4641 if afterHandlerLen != inHandlerLen {
4642 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4643 }
4644 }
4645
4646 func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) }
4647 func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) }
4648 func testHandlerSetsBodyNil(t *testing.T, h2 bool) {
4649 defer afterTest(t)
4650 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4651 r.Body = nil
4652 fmt.Fprintf(w, "%v", r.RemoteAddr)
4653 }))
4654 defer cst.close()
4655 get := func() string {
4656 res, err := cst.c.Get(cst.ts.URL)
4657 if err != nil {
4658 t.Fatal(err)
4659 }
4660 defer res.Body.Close()
4661 slurp, err := io.ReadAll(res.Body)
4662 if err != nil {
4663 t.Fatal(err)
4664 }
4665 return string(slurp)
4666 }
4667 a, b := get(), get()
4668 if a != b {
4669 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4670 }
4671 }
4672
4673
4674
4675 func TestServerValidatesHostHeader(t *testing.T) {
4676 tests := []struct {
4677 proto string
4678 host string
4679 want int
4680 }{
4681 {"HTTP/0.9", "", 505},
4682
4683 {"HTTP/1.1", "", 400},
4684 {"HTTP/1.1", "Host: \r\n", 200},
4685 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4686 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4687 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4688 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4689 {"HTTP/1.1", "Host: ::1\r\n", 200},
4690 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4691 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4692 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4693 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4694 {"HTTP/1.1", "Host: \x06\r\n", 400},
4695 {"HTTP/1.1", "Host: \xff\r\n", 400},
4696 {"HTTP/1.1", "Host: {\r\n", 400},
4697 {"HTTP/1.1", "Host: }\r\n", 400},
4698 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4699
4700
4701
4702 {"HTTP/1.0", "", 200},
4703 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4704 {"HTTP/1.0", "Host: \xff\r\n", 400},
4705
4706
4707 {"PRI * HTTP/2.0", "", 200},
4708
4709
4710 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4711
4712
4713 {"PRI / HTTP/2.0", "", 505},
4714 {"GET / HTTP/2.0", "", 505},
4715 {"GET / HTTP/3.0", "", 505},
4716 }
4717 for _, tt := range tests {
4718 conn := &testConn{closec: make(chan bool, 1)}
4719 methodTarget := "GET / "
4720 if !strings.HasPrefix(tt.proto, "HTTP/") {
4721 methodTarget = ""
4722 }
4723 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4724
4725 ln := &oneConnListener{conn}
4726 srv := Server{
4727 ErrorLog: quietLog,
4728 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4729 }
4730 go srv.Serve(ln)
4731 <-conn.closec
4732 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4733 if err != nil {
4734 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4735 continue
4736 }
4737 if res.StatusCode != tt.want {
4738 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4739 }
4740 }
4741 }
4742
4743 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4744 const upgradeResponse = "upgrade here"
4745 defer afterTest(t)
4746 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4747 conn, br, err := w.(Hijacker).Hijack()
4748 if err != nil {
4749 t.Error(err)
4750 return
4751 }
4752 defer conn.Close()
4753 if r.Method != "PRI" || r.RequestURI != "*" {
4754 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
4755 return
4756 }
4757 if !r.Close {
4758 t.Errorf("Request.Close = true; want false")
4759 }
4760 const want = "SM\r\n\r\n"
4761 buf := make([]byte, len(want))
4762 n, err := io.ReadFull(br, buf)
4763 if err != nil || string(buf[:n]) != want {
4764 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
4765 return
4766 }
4767 io.WriteString(conn, upgradeResponse)
4768 }))
4769 defer ts.Close()
4770
4771 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4772 if err != nil {
4773 t.Fatalf("Dial: %v", err)
4774 }
4775 defer c.Close()
4776 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
4777 slurp, err := io.ReadAll(c)
4778 if err != nil {
4779 t.Fatal(err)
4780 }
4781 if string(slurp) != upgradeResponse {
4782 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
4783 }
4784 }
4785
4786
4787
4788 func TestServerValidatesHeaders(t *testing.T) {
4789 setParallel(t)
4790 tests := []struct {
4791 header string
4792 want int
4793 }{
4794 {"", 200},
4795 {"Foo: bar\r\n", 200},
4796 {"X-Foo: bar\r\n", 200},
4797 {"Foo: a space\r\n", 200},
4798
4799 {"A space: foo\r\n", 400},
4800 {"foo\xffbar: foo\r\n", 400},
4801 {"foo\x00bar: foo\r\n", 400},
4802 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
4803
4804
4805 {"Foo : bar\r\n", 400},
4806 {"Foo\t: bar\r\n", 400},
4807
4808 {"foo: foo foo\r\n", 200},
4809 {"foo: foo\tfoo\r\n", 200},
4810 {"foo: foo\x00foo\r\n", 400},
4811 {"foo: foo\x7ffoo\r\n", 400},
4812 {"foo: foo\xfffoo\r\n", 200},
4813 }
4814 for _, tt := range tests {
4815 conn := &testConn{closec: make(chan bool, 1)}
4816 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
4817
4818 ln := &oneConnListener{conn}
4819 srv := Server{
4820 ErrorLog: quietLog,
4821 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4822 }
4823 go srv.Serve(ln)
4824 <-conn.closec
4825 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4826 if err != nil {
4827 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
4828 continue
4829 }
4830 if res.StatusCode != tt.want {
4831 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
4832 }
4833 }
4834 }
4835
4836 func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) {
4837 testServerRequestContextCancel_ServeHTTPDone(t, h1Mode)
4838 }
4839 func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) {
4840 testServerRequestContextCancel_ServeHTTPDone(t, h2Mode)
4841 }
4842 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
4843 setParallel(t)
4844 defer afterTest(t)
4845 ctxc := make(chan context.Context, 1)
4846 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4847 ctx := r.Context()
4848 select {
4849 case <-ctx.Done():
4850 t.Error("should not be Done in ServeHTTP")
4851 default:
4852 }
4853 ctxc <- ctx
4854 }))
4855 defer cst.close()
4856 res, err := cst.c.Get(cst.ts.URL)
4857 if err != nil {
4858 t.Fatal(err)
4859 }
4860 res.Body.Close()
4861 ctx := <-ctxc
4862 select {
4863 case <-ctx.Done():
4864 default:
4865 t.Error("context should be done after ServeHTTP completes")
4866 }
4867 }
4868
4869
4870
4871
4872
4873 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
4874 setParallel(t)
4875 defer afterTest(t)
4876 inHandler := make(chan struct{})
4877 handlerDone := make(chan struct{})
4878 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4879 close(inHandler)
4880 select {
4881 case <-r.Context().Done():
4882 case <-time.After(3 * time.Second):
4883 t.Errorf("timeout waiting for context to be done")
4884 }
4885 close(handlerDone)
4886 }))
4887 defer ts.Close()
4888 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4889 if err != nil {
4890 t.Fatal(err)
4891 }
4892 defer c.Close()
4893 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
4894 select {
4895 case <-inHandler:
4896 case <-time.After(3 * time.Second):
4897 t.Fatalf("timeout waiting to see ServeHTTP get called")
4898 }
4899 c.Close()
4900
4901 select {
4902 case <-handlerDone:
4903 case <-time.After(4 * time.Second):
4904 t.Fatalf("timeout waiting to see ServeHTTP exit")
4905 }
4906 }
4907
4908 func TestServerContext_ServerContextKey_h1(t *testing.T) {
4909 testServerContext_ServerContextKey(t, h1Mode)
4910 }
4911 func TestServerContext_ServerContextKey_h2(t *testing.T) {
4912 testServerContext_ServerContextKey(t, h2Mode)
4913 }
4914 func testServerContext_ServerContextKey(t *testing.T, h2 bool) {
4915 setParallel(t)
4916 defer afterTest(t)
4917 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4918 ctx := r.Context()
4919 got := ctx.Value(ServerContextKey)
4920 if _, ok := got.(*Server); !ok {
4921 t.Errorf("context value = %T; want *http.Server", got)
4922 }
4923 }))
4924 defer cst.close()
4925 res, err := cst.c.Get(cst.ts.URL)
4926 if err != nil {
4927 t.Fatal(err)
4928 }
4929 res.Body.Close()
4930 }
4931
4932 func TestServerContext_LocalAddrContextKey_h1(t *testing.T) {
4933 testServerContext_LocalAddrContextKey(t, h1Mode)
4934 }
4935 func TestServerContext_LocalAddrContextKey_h2(t *testing.T) {
4936 testServerContext_LocalAddrContextKey(t, h2Mode)
4937 }
4938 func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) {
4939 setParallel(t)
4940 defer afterTest(t)
4941 ch := make(chan any, 1)
4942 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4943 ch <- r.Context().Value(LocalAddrContextKey)
4944 }))
4945 defer cst.close()
4946 if _, err := cst.c.Head(cst.ts.URL); err != nil {
4947 t.Fatal(err)
4948 }
4949
4950 host := cst.ts.Listener.Addr().String()
4951 select {
4952 case got := <-ch:
4953 if addr, ok := got.(net.Addr); !ok {
4954 t.Errorf("local addr value = %T; want net.Addr", got)
4955 } else if fmt.Sprint(addr) != host {
4956 t.Errorf("local addr = %v; want %v", addr, host)
4957 }
4958 case <-time.After(5 * time.Second):
4959 t.Error("timed out")
4960 }
4961 }
4962
4963
4964 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
4965 setParallel(t)
4966 defer afterTest(t)
4967 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4968 w.Header().Set("Transfer-Encoding", "chunked")
4969 w.Write([]byte("hello"))
4970 }))
4971 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4972 const hdr = "Transfer-Encoding: chunked"
4973 if n := strings.Count(resp, hdr); n != 1 {
4974 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
4975 }
4976 }
4977
4978
4979 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
4980 setParallel(t)
4981 defer afterTest(t)
4982 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4983 w.Header().Set("Transfer-Encoding", "gzip")
4984 gz := gzip.NewWriter(w)
4985 gz.Write([]byte("hello"))
4986 gz.Close()
4987 }))
4988 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4989 for _, v := range []string{"gzip", "chunked"} {
4990 hdr := "Transfer-Encoding: " + v
4991 if n := strings.Count(resp, hdr); n != 1 {
4992 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
4993 }
4994 }
4995 }
4996
// BenchmarkClientServer measures sequential GETs with the default client
// against an httptest server that replies "Hello world.\n". Server setup and
// teardown are excluded from the timed region.
func BenchmarkClientServer(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	const want = "Hello world.\n"
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	}))
	defer ts.Close()
	b.StartTimer()

	for remaining := b.N; remaining > 0; remaining-- {
		res, err := Get(ts.URL)
		if err != nil {
			b.Fatal("Get:", err)
		}
		all, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatal("ReadAll:", err)
		}
		if body := string(all); body != want {
			b.Fatal("Got body:", body)
		}
	}

	b.StopTimer()
}
5024
5025 func BenchmarkClientServerParallel4(b *testing.B) {
5026 benchmarkClientServerParallel(b, 4, false)
5027 }
5028
5029 func BenchmarkClientServerParallel64(b *testing.B) {
5030 benchmarkClientServerParallel(b, 64, false)
5031 }
5032
5033 func BenchmarkClientServerParallelTLS4(b *testing.B) {
5034 benchmarkClientServerParallel(b, 4, true)
5035 }
5036
5037 func BenchmarkClientServerParallelTLS64(b *testing.B) {
5038 benchmarkClientServerParallel(b, 64, true)
5039 }
5040
// benchmarkClientServerParallel starts an httptest server (TLS when useTLS)
// that answers "Hello world.\n", then issues GETs from b.RunParallel
// goroutines after b.SetParallelism(parallelism). Request and read errors
// are logged and that iteration is skipped. An unexpected body panics.
func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) {
	b.ReportAllocs()
	ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	}))
	start := ts.Start
	if useTLS {
		start = ts.StartTLS
	}
	start()
	defer ts.Close()
	b.ResetTimer()
	b.SetParallelism(parallelism)
	b.RunParallel(func(pb *testing.PB) {
		client := ts.Client()
		for pb.Next() {
			res, err := client.Get(ts.URL)
			if err != nil {
				b.Logf("Get: %v", err)
				continue
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				b.Logf("ReadAll: %v", err)
				continue
			}
			if body := string(all); body != "Hello world.\n" {
				panic("Got body: " + body)
			}
		}
	})
}
5075
5076
5077
5078
5079
5080
5081
5082
5083
// BenchmarkServer measures server throughput with the client running in a
// separate process.
//
// When TEST_BENCH_SERVER_URL is set, this process is that child client: it
// performs TEST_BENCH_CLIENT_N GETs against the URL, panics on any error or
// unexpected body, and exits with status 0. Otherwise it starts the server
// and re-executes the test binary as the child, passing b.N and the URL via
// the environment.
func BenchmarkServer(b *testing.B) {
	b.ReportAllocs()

	if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
		// Child process: act as the client.
		count, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
		if err != nil {
			panic(err)
		}
		for ; count > 0; count-- {
			res, err := Get(url)
			if err != nil {
				log.Panicf("Get: %v", err)
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				log.Panicf("ReadAll: %v", err)
			}
			if body := string(all); body != "Hello world.\n" {
				log.Panicf("Got body: %q", body)
			}
		}
		os.Exit(0)
		return
	}

	payload := []byte("Hello world.\n")
	b.StopTimer()
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(payload)
	}))
	defer ts.Close()
	b.StartTimer()

	child := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer$")
	child.Env = append([]string{
		fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
		fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
	}, os.Environ()...)
	if out, err := child.CombinedOutput(); err != nil {
		b.Errorf("Test failure: %v, with output: %s", err, out)
	}
}
5130
5131
// getNoBody issues a GET for urlStr with the package-level Get and closes the
// response body without reading it. On error it returns a nil *Response.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5140
5141
5142
// BenchmarkClient measures client request throughput against a server that
// runs in a separate child process.
//
// When TEST_BENCH_SERVER is set, this invocation is the child. It listens on
// localhost (port from TEST_BENCH_SERVER_PORT, or "0" if unset), prints the
// listening address on stdout, and serves data until a request with a
// non-empty "stop" form value makes it exit with status 0.
//
// Otherwise it re-executes the test binary as that child, reads the address
// from the child's stdout, probes it once, and times b.N GETs. It then asks
// the child to stop and waits up to 5 seconds for it to exit cleanly.
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Child process: act as the server.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		// The parent reads this line to learn where to connect.
		fmt.Println(ln.Addr().String())
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		// srv.Handler is nil, so net/http falls back to DefaultServeMux,
		// where HandleFunc registered the handler above.
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent process: start the child server.
	cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$")
	cmd.Env = append(os.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}
	defer cmd.Process.Kill()

	// Kill the child if it has not reported its address within 10 seconds,
	// so the Scan below cannot block forever.
	timer := time.AfterFunc(10*time.Second, func() {
		cmd.Process.Kill()
	})
	defer timer.Stop()
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	timer.Stop()
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// Reap the child in the background. cmd.Wait() is evaluated before the
	// select blocks, so this goroutine first waits for the child to exit and
	// then either delivers the result on done or, if the benchmark has
	// already returned (stop closed), drops it.
	done := make(chan error)
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-stop:
			return
		case done <- cmd.Wait():
		}
	}()

	// Timed region: b.N GETs against the child server.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit, then require a clean exit within 5 seconds.
	getNoBody(url + "?stop=yes")
	select {
	case err := <-done:
		if err != nil {
			b.Fatalf("subprocess failed: %v", err)
		}
	case <-time.After(5 * time.Second):
		b.Fatalf("subprocess did not stop")
	}
}
5242
5243 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5244 b.ReportAllocs()
5245 req := reqBytes(`GET / HTTP/1.0
5246 Host: golang.org
5247 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5248 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5249 Accept-Encoding: gzip,deflate,sdch
5250 Accept-Language: en-US,en;q=0.8
5251 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5252 `)
5253 res := []byte("Hello world!\n")
5254
5255 conn := &testConn{
5256
5257
5258 closec: make(chan bool, 1),
5259 }
5260 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5261 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5262 rw.Write(res)
5263 })
5264 ln := new(oneConnListener)
5265 for i := 0; i < b.N; i++ {
5266 conn.readBuf.Reset()
5267 conn.writeBuf.Reset()
5268 conn.readBuf.Write(req)
5269 ln.conn = conn
5270 Serve(ln, handler)
5271 <-conn.closec
5272 }
5273 }
5274
5275
// repeatReader is an io.Reader that yields content back to back count times
// and then returns io.EOF. off is the read position within the current
// repetition.
type repeatReader struct {
	content []byte
	count   int
	off     int
}

// Read copies as much of the rest of the current repetition as fits in p.
// Once a repetition has been fully consumed, off rewinds to zero and count is
// decremented. A single call never spans two repetitions.
func (r *repeatReader) Read(p []byte) (n int, err error) {
	if r.count <= 0 {
		return 0, io.EOF
	}
	n = copy(p, r.content[r.off:])
	if r.off += n; r.off == len(r.content) {
		r.off, r.count = 0, r.count-1
	}
	return n, nil
}
5294
5295 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5296 b.ReportAllocs()
5297
5298 req := reqBytes(`GET / HTTP/1.1
5299 Host: golang.org
5300 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5301 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5302 Accept-Encoding: gzip,deflate,sdch
5303 Accept-Language: en-US,en;q=0.8
5304 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5305 `)
5306 res := []byte("Hello world!\n")
5307
5308 conn := &rwTestConn{
5309 Reader: &repeatReader{content: req, count: b.N},
5310 Writer: io.Discard,
5311 closec: make(chan bool, 1),
5312 }
5313 handled := 0
5314 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5315 handled++
5316 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5317 rw.Write(res)
5318 })
5319 ln := &oneConnListener{conn: conn}
5320 go Serve(ln, handler)
5321 <-conn.closec
5322 if b.N != handled {
5323 b.Errorf("b.N=%d but handled %d", b.N, handled)
5324 }
5325 }
5326
5327
5328
5329 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5330 b.ReportAllocs()
5331
5332 req := reqBytes(`GET / HTTP/1.1
5333 Host: golang.org
5334 `)
5335 res := []byte("Hello world!\n")
5336
5337 conn := &rwTestConn{
5338 Reader: &repeatReader{content: req, count: b.N},
5339 Writer: io.Discard,
5340 closec: make(chan bool, 1),
5341 }
5342 handled := 0
5343 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5344 handled++
5345 rw.Write(res)
5346 })
5347 ln := &oneConnListener{conn: conn}
5348 go Serve(ln, handler)
5349 <-conn.closec
5350 if b.N != handled {
5351 b.Errorf("b.N=%d but handled %d", b.N, handled)
5352 }
5353 }
5354
// someResponse is the HTML fragment repeated to build response.
const someResponse = "<html>some response</html>"

// response holds as many whole copies of someResponse as fit in 2 KiB
// (2<<10 bytes). It is the body written by the BenchmarkServerHandler*
// handlers.
var response = bytes.Repeat([]byte(someResponse), (2<<10)/len(someResponse))
5359
5360
5361 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5362 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5363 w.Header().Set("Content-Type", "text/html")
5364 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5365 w.Write(response)
5366 }))
5367 }
5368
5369
5370 func BenchmarkServerHandlerNoLen(b *testing.B) {
5371 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5372 w.Header().Set("Content-Type", "text/html")
5373 w.Write(response)
5374 }))
5375 }
5376
5377
5378 func BenchmarkServerHandlerNoType(b *testing.B) {
5379 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5380 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5381 w.Write(response)
5382 }))
5383 }
5384
5385
5386 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5387 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5388 w.Write(response)
5389 }))
5390 }
5391
5392 func benchmarkHandler(b *testing.B, h Handler) {
5393 b.ReportAllocs()
5394 req := reqBytes(`GET / HTTP/1.1
5395 Host: golang.org
5396 `)
5397 conn := &rwTestConn{
5398 Reader: &repeatReader{content: req, count: b.N},
5399 Writer: io.Discard,
5400 closec: make(chan bool, 1),
5401 }
5402 handled := 0
5403 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5404 handled++
5405 h.ServeHTTP(rw, r)
5406 })
5407 ln := &oneConnListener{conn: conn}
5408 go Serve(ln, handler)
5409 <-conn.closec
5410 if b.N != handled {
5411 b.Errorf("b.N=%d but handled %d", b.N, handled)
5412 }
5413 }
5414
5415 func BenchmarkServerHijack(b *testing.B) {
5416 b.ReportAllocs()
5417 req := reqBytes(`GET / HTTP/1.1
5418 Host: golang.org
5419 `)
5420 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5421 conn, _, err := w.(Hijacker).Hijack()
5422 if err != nil {
5423 panic(err)
5424 }
5425 conn.Close()
5426 })
5427 conn := &rwTestConn{
5428 Writer: io.Discard,
5429 closec: make(chan bool, 1),
5430 }
5431 ln := &oneConnListener{conn: conn}
5432 for i := 0; i < b.N; i++ {
5433 conn.Reader = bytes.NewReader(req)
5434 ln.conn = conn
5435 Serve(ln, h)
5436 <-conn.closec
5437 }
5438 }
5439
// BenchmarkCloseNotifier measures how quickly a handler blocked on
// CloseNotify learns that the client has gone away. Each iteration dials,
// sends one keep-alive GET, closes the connection, and waits up to 5 seconds
// for the handler to report the close.
func BenchmarkCloseNotifier(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	sawClose := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
		<-rw.(CloseNotifier).CloseNotify()
		sawClose <- true
	}))
	defer ts.Close()
	timeout := time.NewTimer(5 * time.Second)
	defer timeout.Stop()
	b.StartTimer()
	for iter := 0; iter < b.N; iter++ {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		if _, err := fmt.Fprintf(c, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			b.Fatal(err)
		}
		c.Close()
		timeout.Reset(5 * time.Second)
		select {
		case <-sawClose:
		case <-timeout.C:
			b.Fatal("timeout")
		}
	}
	b.StopTimer()
}
5471
5472
5473 func TestConcurrentServerServe(t *testing.T) {
5474 setParallel(t)
5475 for i := 0; i < 100; i++ {
5476 ln1 := &oneConnListener{conn: nil}
5477 ln2 := &oneConnListener{conn: nil}
5478 srv := Server{}
5479 go func() { srv.Serve(ln1) }()
5480 go func() { srv.Serve(ln2) }()
5481 }
5482 }
5483
5484 func TestServerIdleTimeout(t *testing.T) {
5485 if testing.Short() {
5486 t.Skip("skipping in short mode")
5487 }
5488 setParallel(t)
5489 defer afterTest(t)
5490 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5491 io.Copy(io.Discard, r.Body)
5492 io.WriteString(w, r.RemoteAddr)
5493 }))
5494 ts.Config.ReadHeaderTimeout = 1 * time.Second
5495 ts.Config.IdleTimeout = 2 * time.Second
5496 ts.Start()
5497 defer ts.Close()
5498 c := ts.Client()
5499
5500 get := func() string {
5501 res, err := c.Get(ts.URL)
5502 if err != nil {
5503 t.Fatal(err)
5504 }
5505 defer res.Body.Close()
5506 slurp, err := io.ReadAll(res.Body)
5507 if err != nil {
5508 t.Fatal(err)
5509 }
5510 return string(slurp)
5511 }
5512
5513 a1, a2 := get(), get()
5514 if a1 != a2 {
5515 t.Fatalf("did requests on different connections")
5516 }
5517 time.Sleep(3 * time.Second)
5518 a3 := get()
5519 if a2 == a3 {
5520 t.Fatal("request three unexpectedly on same connection")
5521 }
5522
5523
5524 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5525 if err != nil {
5526 t.Fatal(err)
5527 }
5528 defer conn.Close()
5529 conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
5530 time.Sleep(2 * time.Second)
5531 if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
5532 t.Fatal("copy byte succeeded; want err")
5533 }
5534 }
5535
// get fetches url with c and returns the whole response body as a string.
// It fails the test immediately if the request or the body read fails.
func get(t *testing.T, c *Client, url string) string {
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	body, readErr := io.ReadAll(res.Body)
	res.Body.Close()
	if readErr != nil {
		t.Fatal(readErr)
	}
	return string(body)
}
5548
5549
5550
5551 func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
5552 setParallel(t)
5553 defer afterTest(t)
5554 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5555 io.WriteString(w, r.RemoteAddr)
5556 }))
5557 defer ts.Close()
5558
5559 c := ts.Client()
5560 tr := c.Transport.(*Transport)
5561
5562 get := func() string { return get(t, c, ts.URL) }
5563
5564 a1, a2 := get(), get()
5565 if a1 != a2 {
5566 t.Fatal("expected first two requests on same connection")
5567 }
5568 addr := strings.TrimPrefix(ts.URL, "http://")
5569
5570
5571
5572
5573
5574 n := tr.IdleConnCountForTesting("http", addr)
5575 if n != 1 {
5576 t.Fatalf("idle count for %q after 2 gets = %d, want 1", addr, n)
5577 }
5578
5579
5580 ts.Config.SetKeepAlivesEnabled(false)
5581
5582 var idle1 int
5583 if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
5584 idle1 = tr.IdleConnCountForTesting("http", addr)
5585 return idle1 == 0
5586 }) {
5587 t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1)
5588 }
5589
5590 a3 := get()
5591 if a3 == a2 {
5592 t.Fatal("expected third request on new connection")
5593 }
5594 }
5595
5596 func TestServerShutdown_h1(t *testing.T) {
5597 testServerShutdown(t, h1Mode)
5598 }
5599 func TestServerShutdown_h2(t *testing.T) {
5600 testServerShutdown(t, h2Mode)
5601 }
5602
// testServerShutdown exercises graceful Shutdown while a request is in flight.
//
// The handler snapshots the server's connection counts by state, starts
// Shutdown in a new goroutine, and then still writes its response. The test
// requires that:
//   - the in-flight request completes,
//   - Shutdown returns nil,
//   - the RegisterOnShutdown callback runs,
//   - the snapshot shows exactly one StateActive connection,
//   - a later request fails.
func testServerShutdown(t *testing.T, h2 bool) {
	setParallel(t)
	defer afterTest(t)
	// These hooks need cst.ts.Config, so they are declared here and assigned
	// after newClientServerTest returns, before the first request is sent.
	var doShutdown func()
	var doStateCount func()
	var shutdownRes = make(chan error, 1)
	var statesRes = make(chan map[ConnState]int, 1)
	var gotOnShutdown = make(chan struct{}, 1)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		doStateCount()
		go doShutdown()
		// NOTE(review): the short sleep presumably lets Shutdown get under way
		// while this response is still pending. A non-graceful Shutdown would
		// then make the get below fail.
		time.Sleep(20 * time.Millisecond)
		io.WriteString(w, r.RemoteAddr)
	})
	cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} })
	})
	defer cst.close()

	doShutdown = func() {
		shutdownRes <- cst.ts.Config.Shutdown(context.Background())
	}
	doStateCount = func() {
		statesRes <- cst.ts.Config.ExportAllConnsByState()
	}
	// get fails the test if the in-flight request does not complete.
	get(t, cst.c, cst.ts.URL)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	select {
	case <-gotOnShutdown:
	case <-time.After(5 * time.Second):
		t.Errorf("onShutdown callback not called, RegisterOnShutdown broken?")
	}

	// The snapshot was taken inside the handler, so the serving connection
	// is expected to be counted as active.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}

	// After Shutdown the server must not serve new requests.
	res, err := cst.c.Get(cst.ts.URL)
	if err == nil {
		res.Body.Close()
		t.Fatal("second request should fail. server should be shut down")
	}
}
5653
5654 func TestServerShutdownStateNew(t *testing.T) {
5655 if testing.Short() {
5656 t.Skip("test takes 5-6 seconds; skipping in short mode")
5657 }
5658 setParallel(t)
5659 defer afterTest(t)
5660
5661 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5662
5663 }))
5664 var connAccepted sync.WaitGroup
5665 ts.Config.ConnState = func(conn net.Conn, state ConnState) {
5666 if state == StateNew {
5667 connAccepted.Done()
5668 }
5669 }
5670 ts.Start()
5671 defer ts.Close()
5672
5673
5674 connAccepted.Add(1)
5675 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5676 if err != nil {
5677 t.Fatal(err)
5678 }
5679 defer c.Close()
5680
5681
5682
5683
5684
5685 connAccepted.Wait()
5686
5687 shutdownRes := make(chan error, 1)
5688 go func() {
5689 shutdownRes <- ts.Config.Shutdown(context.Background())
5690 }()
5691 readRes := make(chan error, 1)
5692 go func() {
5693 _, err := c.Read([]byte{0})
5694 readRes <- err
5695 }()
5696
5697 const expectTimeout = 5 * time.Second
5698 t0 := time.Now()
5699 select {
5700 case got := <-shutdownRes:
5701 d := time.Since(t0)
5702 if got != nil {
5703 t.Fatalf("shutdown error after %v: %v", d, err)
5704 }
5705 if d < expectTimeout/2 {
5706 t.Errorf("shutdown too soon after %v", d)
5707 }
5708 case <-time.After(expectTimeout * 3 / 2):
5709 t.Fatalf("timeout waiting for shutdown")
5710 }
5711
5712
5713
5714 select {
5715 case err := <-readRes:
5716 if err == nil {
5717 t.Error("expected error from Read")
5718 }
5719 case <-time.After(2 * time.Second):
5720 t.Errorf("timeout waiting for Read to unblock")
5721 }
5722 }
5723
5724
// TestServerCloseDeadlock checks that calling Close twice on a zero-value
// Server returns rather than hanging.
func TestServerCloseDeadlock(t *testing.T) {
	s := new(Server)
	for n := 0; n < 2; n++ {
		s.Close()
	}
}
5730
5731
5732
5733 func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) }
5734 func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) }
5735 func testServerKeepAlivesEnabled(t *testing.T, h2 bool) {
5736 if h2 {
5737 restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
5738 defer restore()
5739 }
5740
5741 defer afterTest(t)
5742 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {}))
5743 defer cst.close()
5744 srv := cst.ts.Config
5745 srv.SetKeepAlivesEnabled(false)
5746 for try := 0; try < 2; try++ {
5747 if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
5748 t.Fatalf("request %v: test server has active conns", try)
5749 }
5750 conns := 0
5751 var info httptrace.GotConnInfo
5752 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5753 GotConn: func(v httptrace.GotConnInfo) {
5754 conns++
5755 info = v
5756 },
5757 })
5758 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
5759 if err != nil {
5760 t.Fatal(err)
5761 }
5762 res, err := cst.c.Do(req)
5763 if err != nil {
5764 t.Fatal(err)
5765 }
5766 res.Body.Close()
5767 if conns != 1 {
5768 t.Fatalf("request %v: got %v conns, want 1", try, conns)
5769 }
5770 if info.Reused || info.WasIdle {
5771 t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
5772 }
5773 }
5774 }
5775
5776
5777
5778
5779 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
5780 setParallel(t)
5781 defer afterTest(t)
5782 runTimeSensitiveTest(t, []time.Duration{
5783 10 * time.Millisecond,
5784 50 * time.Millisecond,
5785 250 * time.Millisecond,
5786 time.Second,
5787 2 * time.Second,
5788 }, func(t *testing.T, timeout time.Duration) error {
5789 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5790 select {
5791 case <-time.After(2 * timeout):
5792 fmt.Fprint(w, "ok")
5793 case <-r.Context().Done():
5794 fmt.Fprint(w, r.Context().Err())
5795 }
5796 }))
5797 ts.Config.ReadTimeout = timeout
5798 ts.Start()
5799 defer ts.Close()
5800
5801 c := ts.Client()
5802
5803 res, err := c.Get(ts.URL)
5804 if err != nil {
5805 return fmt.Errorf("Get: %v", err)
5806 }
5807 slurp, err := io.ReadAll(res.Body)
5808 res.Body.Close()
5809 if err != nil {
5810 return fmt.Errorf("Body ReadAll: %v", err)
5811 }
5812 if string(slurp) != "ok" {
5813 return fmt.Errorf("got: %q, want ok", slurp)
5814 }
5815 return nil
5816 })
5817 }
5818
5819
5820
// runTimeSensitiveTest calls test with each duration in order and stops at
// the first call that returns nil. If every call fails, it fails t with the
// last duration and its error. An empty durations slice calls nothing and
// does not fail.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	var (
		last time.Duration
		err  error
	)
	for _, d := range durations {
		last = d
		if err = test(t, d); err == nil {
			return
		}
	}
	if err != nil {
		t.Fatalf("failed with duration %v: %v", last, err)
	}
}
5832
5833
5834
// TestServerDuplicateBackgroundRead opens several connections to a NotFound
// server. On each one it writes many pipelined GET requests back to back
// while a second goroutine drains and discards whatever the server sends.
// The test itself only fails if dialing or writing fails. NOTE(review):
// presumably it exists to surface races or panics in the server's
// background-read handling under pipelining (TODO confirm).
func TestServerDuplicateBackgroundRead(t *testing.T) {
	if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
		testenv.SkipFlaky(t, 24826)
	}

	setParallel(t)
	defer afterTest(t)

	goroutines := 5
	requests := 2000
	if testing.Short() {
		goroutines = 3
		requests = 100
	}

	hts := httptest.NewServer(HandlerFunc(NotFound))
	defer hts.Close()

	// This local shadows the package-level reqBytes helper.
	reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")

	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			cn, err := net.Dial("tcp", hts.Listener.Addr().String())
			if err != nil {
				t.Error(err)
				return
			}
			defer cn.Close()

			// Drain responses, presumably so the server is never blocked on a
			// full send buffer. This Add is safe because this goroutine's own
			// count is still held. io.Copy returns once cn is closed (e.g. by
			// the deferred Close above).
			wg.Add(1)
			go func() {
				defer wg.Done()
				io.Copy(io.Discard, cn)
			}()

			for j := 0; j < requests; j++ {
				if t.Failed() {
					return
				}
				_, err := cn.Write(reqBytes)
				if err != nil {
					t.Error(err)
					return
				}
			}
		}()
	}
	wg.Wait()
}
5887
5888
5889
5890
5891
5892
// TestServerHijackGetsBackgroundByte checks two things about bytes the client
// sends after its request, once the handler is already running. First, they
// are readable through the bufio.Reader returned by Hijack. Second, the
// request context is still live after hijacking.
func TestServerHijackGetsBackgroundByte(t *testing.T) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	setParallel(t)
	defer afterTest(t)
	done := make(chan struct{})
	inHandler := make(chan bool, 1)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		// Tell the client it may now send the extra bytes. They therefore
		// arrive only after the request was parsed and this handler started.
		inHandler <- true

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		// The post-request bytes must be visible via the hijacked reader.
		peek, err := buf.Reader.Peek(3)
		if string(peek) != "foo" || err != nil {
			t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
		}

		// Hijacking must not have canceled the request context.
		select {
		case <-r.Context().Done():
			t.Error("context unexpectedly canceled")
		default:
		}
	}))
	defer ts.Close()

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	<-inHandler
	if _, err := cn.Write([]byte("foo")); err != nil {
		t.Fatal(err)
	}

	// Half-close the client side after "foo", then wait for the handler to
	// finish its checks.
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Error("timeout")
	}
}
5949
5950
5951
5952
5953 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
5954 if runtime.GOOS == "plan9" {
5955 t.Skip("skipping test; see https://golang.org/issue/18657")
5956 }
5957 setParallel(t)
5958 defer afterTest(t)
5959 done := make(chan struct{})
5960 const size = 8 << 10
5961 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5962 defer close(done)
5963
5964 conn, buf, err := w.(Hijacker).Hijack()
5965 if err != nil {
5966 t.Error(err)
5967 return
5968 }
5969 defer conn.Close()
5970 slurp, err := io.ReadAll(buf.Reader)
5971 if err != nil {
5972 t.Errorf("Copy: %v", err)
5973 }
5974 allX := true
5975 for _, v := range slurp {
5976 if v != 'x' {
5977 allX = false
5978 }
5979 }
5980 if len(slurp) != size {
5981 t.Errorf("read %d; want %d", len(slurp), size)
5982 } else if !allX {
5983 t.Errorf("read %q; want %d 'x'", slurp, size)
5984 }
5985 }))
5986 defer ts.Close()
5987
5988 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
5989 if err != nil {
5990 t.Fatal(err)
5991 }
5992 defer cn.Close()
5993 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
5994 strings.Repeat("x", size)); err != nil {
5995 t.Fatal(err)
5996 }
5997 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
5998 t.Fatal(err)
5999 }
6000
6001 <-done
6002 }
6003
6004
6005 func TestServerValidatesMethod(t *testing.T) {
6006 tests := []struct {
6007 method string
6008 want int
6009 }{
6010 {"GET", 200},
6011 {"GE(T", 400},
6012 }
6013 for _, tt := range tests {
6014 conn := &testConn{closec: make(chan bool, 1)}
6015 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6016
6017 ln := &oneConnListener{conn}
6018 go Serve(ln, serve(200))
6019 <-conn.closec
6020 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6021 if err != nil {
6022 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6023 continue
6024 }
6025 if res.StatusCode != tt.want {
6026 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6027 }
6028 }
6029 }
6030
6031
// eofListenerNotComparable is a net.Listener whose underlying type is a
// slice, so its values cannot be compared with ==. Accept always fails with
// io.EOF, Addr is nil, and Close is a no-op.
type eofListenerNotComparable []int

// Accept never yields a connection; it always reports io.EOF.
func (eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

// Addr returns nil.
func (eofListenerNotComparable) Addr() net.Addr {
	return nil
}

// Close does nothing and returns nil.
func (eofListenerNotComparable) Close() error {
	return nil
}
6037
6038
6039 func TestServerListenNotComparableListener(t *testing.T) {
6040 var s Server
6041 s.Serve(make(eofListenerNotComparable, 1))
6042 }
6043
6044
// countCloseListener wraps a net.Listener and atomically counts calls to
// Close. Only the first call is forwarded to the wrapped listener, and only
// when one is set.
type countCloseListener struct {
	net.Listener
	closes int32
}

// Close records the call. On the first call it closes the embedded listener
// (if non-nil) and returns that error; every other call returns nil.
func (p *countCloseListener) Close() error {
	n := atomic.AddInt32(&p.closes, 1)
	if n != 1 || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
6057
6058
// TestServerCloseListenerOnce starts Serve on a counting wrapper around a
// local listener, then calls Shutdown. Once Serve has returned, the wrapper
// must have seen exactly one Close call.
func TestServerCloseListenerOnce(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()

	counted := &countCloseListener{Listener: ln}
	srv := &Server{}
	served := make(chan bool, 1)

	go func() {
		srv.Serve(counted)
		served <- true
	}()
	// Give Serve a moment to start before shutting the server down.
	time.Sleep(10 * time.Millisecond)
	srv.Shutdown(context.Background())
	// Closing the inner listener directly is not counted by the wrapper.
	ln.Close()
	<-served

	if got := atomic.LoadInt32(&counted.closes); got != 1 {
		t.Errorf("Close calls = %v; want 1", got)
	}
}
6084
6085
// TestServerShutdownThenServe checks two things about calling Serve on a
// server that was already shut down: Serve returns ErrServerClosed, and the
// listener it was given is still closed exactly once.
func TestServerShutdownThenServe(t *testing.T) {
	var srv Server
	ln := &countCloseListener{}
	srv.Shutdown(context.Background())

	if err := srv.Serve(ln); err != ErrServerClosed {
		t.Errorf("Serve err = %v; want ErrServerClosed", err)
	}
	if closes := atomic.LoadInt32(&ln.closes); closes != 1 {
		t.Errorf("Close calls = %v; want 1", closes)
	}
}
6099
6100
// TestStripPortFromHost checks how ServeMux routes a request for
// "http://example.com:9000/". It must reach the handler registered for the
// port-less host pattern "example.com/", not the one for "example.com:9000/".
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()

	// replying returns a handler that writes body verbatim.
	replying := func(body string) func(ResponseWriter, *Request) {
		return func(w ResponseWriter, _ *Request) {
			fmt.Fprint(w, body)
		}
	}
	mux.HandleFunc("example.com/", replying("OK"))
	mux.HandleFunc("example.com:9000/", replying("uh-oh!"))

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6121
// TestServerContexts checks that a request's context is built on top of
// what Server.BaseContext and then Server.ConnContext return. The handler
// must see the values stored by both hooks.
func TestServerContexts(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	type baseKey struct{}
	type connKey struct{}
	reqCtx := make(chan context.Context, 1)
	ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		reqCtx <- r.Context()
	}))
	ts.Config.BaseContext = func(ln net.Listener) context.Context {
		// BaseContext must receive the caller's listener, not an internal
		// onceClose wrapper around it.
		if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
			t.Errorf("unexpected onceClose listener type %T", ln)
		}
		return context.WithValue(context.Background(), baseKey{}, "base")
	}
	ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
		// ConnContext's input must already carry BaseContext's value.
		if got, want := ctx.Value(baseKey{}), "base"; got != want {
			t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
		}
		return context.WithValue(ctx, connKey{}, "conn")
	}
	ts.Start()
	defer ts.Close()

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	ctx := <-reqCtx
	for _, c := range []struct {
		label string
		key   any
		want  string
	}{
		{"base", baseKey{}, "base"},
		{"conn", connKey{}, "conn"},
	} {
		if got := ctx.Value(c.key); got != c.want {
			t.Errorf("%s context key = %#v; want %q", c.label, got, c.want)
		}
	}
}
6158
// TestServerContextsHTTP2 makes the same BaseContext/ConnContext checks over
// a TLS server that offers "h2" and a client forced to attempt HTTP/2. The
// handler also fails the test if the request is not HTTP/2.
func TestServerContextsHTTP2(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	type baseKey struct{}
	type connKey struct{}
	reqCtx := make(chan context.Context, 1)
	ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		if r.ProtoMajor != 2 {
			t.Errorf("unexpected HTTP/1.x request")
		}
		reqCtx <- r.Context()
	}))
	ts.Config.BaseContext = func(ln net.Listener) context.Context {
		// BaseContext must receive the caller's listener, not an internal
		// onceClose wrapper around it.
		if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
			t.Errorf("unexpected onceClose listener type %T", ln)
		}
		return context.WithValue(context.Background(), baseKey{}, "base")
	}
	ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
		// ConnContext's input must already carry BaseContext's value.
		if got, want := ctx.Value(baseKey{}), "base"; got != want {
			t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
		}
		return context.WithValue(ctx, connKey{}, "conn")
	}
	ts.TLS = &tls.Config{
		NextProtos: []string{"h2", "http/1.1"},
	}
	ts.StartTLS()
	defer ts.Close()

	client := ts.Client()
	client.Transport.(*Transport).ForceAttemptHTTP2 = true
	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	ctx := <-reqCtx
	for _, c := range []struct {
		label string
		key   any
		want  string
	}{
		{"base", baseKey{}, "base"},
		{"conn", connKey{}, "conn"},
	} {
		if got := ctx.Value(c.key); got != c.want {
			t.Errorf("%s context key = %#v; want %q", c.label, got, c.want)
		}
	}
}
6202
6203
// TestConnContextNotModifyingAllContexts checks that the context ConnContext
// returns for one connection does not leak into the context handed to
// ConnContext for a later connection. The handler replies with
// "Connection: close", so the second request should need a new connection.
func TestConnContextNotModifyingAllContexts(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	type connKey struct{}
	ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Connection", "close")
	}))
	ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
		// A fresh connection must start without any previous connection's value.
		if got := ctx.Value(connKey{}); got != nil {
			t.Errorf("in ConnContext, unexpected context key = %#v", got)
		}
		return context.WithValue(ctx, connKey{}, "conn")
	}
	ts.Start()
	defer ts.Close()

	// Two sequential requests, each expected on its own connection.
	for i := 0; i < 2; i++ {
		res, err := ts.Client().Get(ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
}
6235
6236
6237
6238 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6239 cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
6240 w.Write([]byte("Hello, World!"))
6241 }))
6242 defer cst.Close()
6243
6244 serverURL, err := url.Parse(cst.URL)
6245 if err != nil {
6246 t.Fatalf("Failed to parse server URL: %v", err)
6247 }
6248
6249 unsupportedTEs := []string{
6250 "fugazi",
6251 "foo-bar",
6252 "unknown",
6253 }
6254
6255 for _, badTE := range unsupportedTEs {
6256 http1ReqBody := fmt.Sprintf(""+
6257 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6258 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6259
6260 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6261 if err != nil {
6262 t.Errorf("%q. unexpected error: %v", badTE, err)
6263 continue
6264 }
6265
6266 wantBody := fmt.Sprintf("" +
6267 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6268 "Connection: close\r\n\r\nUnsupported transfer encoding")
6269
6270 if string(gotBody) != wantBody {
6271 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6272 }
6273 }
6274 }
6275
// TestContentEncodingNoSniffing_h1 runs testContentEncodingNoSniffing in h1Mode.
func TestContentEncodingNoSniffing_h1(t *testing.T) {
	mode := h1Mode
	testContentEncodingNoSniffing(t, mode)
}

// TestContentEncodingNoSniffing_h2 runs testContentEncodingNoSniffing in h2Mode.
func TestContentEncodingNoSniffing_h2(t *testing.T) {
	mode := h2Mode
	testContentEncodingNoSniffing(t, mode)
}
6283
6284
// testContentEncodingNoSniffing checks how the server fills in Content-Type
// for handler bodies written without one. If the handler sets a non-empty
// Content-Encoding, no Content-Type should be sniffed. If Content-Encoding is
// unset or empty, the sniffed type should be returned. h2 is passed through
// to newClientServerTest to pick the protocol mode.
func testContentEncodingNoSniffing(t *testing.T, h2 bool) {
	setParallel(t)
	defer afterTest(t)

	const page = "doctype html><p>Hello</p>"

	// gzipped returns page compressed with gzip.
	gzipped := func() []byte {
		var buf bytes.Buffer
		zw := gzip.NewWriter(&buf)
		zw.Write([]byte(page))
		zw.Close()
		return buf.Bytes()
	}
	// zlibbed returns page compressed with zlib.
	zlibbed := func() []byte {
		var buf bytes.Buffer
		zw := zlib.NewWriter(&buf)
		zw.Write([]byte(page))
		zw.Close()
		return buf.Bytes()
	}

	type setting struct {
		name string
		body []byte

		// contentEncoding is what the handler puts in Content-Encoding:
		// nil leaves the header unset; a string (even "") is set as-is.
		contentEncoding any
		// wantContentType is the Content-Type the client should receive.
		wantContentType string
	}

	settings := []*setting{
		{
			name:            "gzip content-encoding, gzipped",
			contentEncoding: "application/gzip",
			wantContentType: "",
			body:            gzipped(),
		},
		{
			name:            "zlib content-encoding, zlibbed",
			contentEncoding: "application/zlib",
			wantContentType: "",
			body:            zlibbed(),
		},
		{
			name:            "no content-encoding",
			wantContentType: "application/x-gzip",
			body:            gzipped(),
		},
		{
			name:            "phony content-encoding",
			contentEncoding: "foo/bar",
			body:            []byte(page),
		},
		{
			name:            "empty but set content-encoding",
			contentEncoding: "",
			wantContentType: "audio/mpeg",
			body:            []byte("ID3"),
		},
	}

	for _, tt := range settings {
		t.Run(tt.name, func(t *testing.T) {
			handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
				if tt.contentEncoding != nil {
					rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
				}
				rw.Write(tt.body)
			})
			cst := newClientServerTest(t, h2, handler)
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Failed to fetch URL: %v", err)
			}
			defer res.Body.Close()

			gotCE := res.Header.Get("Content-Encoding")
			if want := tt.contentEncoding; want == nil {
				// The handler set nothing, so nothing may come back.
				if gotCE != "" {
					t.Errorf("Unexpected Content-Encoding %q", gotCE)
				}
			} else if gotCE != want {
				t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", gotCE, want)
			}

			if got := res.Header.Get("Content-Type"); got != tt.wantContentType {
				t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", got, tt.wantContentType)
			}
		})
	}
}
6379
6380
6381
// TestTimeoutHandlerSuperfluousLogs checks that when a handler wrapped in
// TimeoutHandler calls WriteHeader repeatedly, each extra call is written to
// the server's ErrorLog as a "superfluous response.WriteHeader call" entry
// naming the handler's own function, file and line. It covers two cases: a
// handler that returns before the timeout, and one still blocked when the
// timeout fires.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	setParallel(t)
	defer afterTest(t)

	// Name this test function and its file so the expected log entries can
	// be matched against them.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool // keep the handler blocked until after the timeout
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// exitHandler releases the handler. It is closed when the subtest
			// ends so a handler still blocked on it can return.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			// lastLine receives the line of the handler's runtime.Caller(0).
			lastLine := make(chan int, 1)

			// The four WriteHeader calls must stay on consecutive lines
			// directly above runtime.Caller(0). The first call is valid; the
			// other three are expected in the log at line-3, line-2 and
			// line-1. Do not insert lines inside this handler.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// Pre-fill the buffered channel so the handler returns at once,
			// unless this case needs it to outlive the timeout.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(bytes.Buffer)
			srvLog := log.New(logBuf, "", 0)
			// Use a short timeout when the handler is meant to be cut off.
			// Otherwise use a long one so the handler reliably finishes first.
			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, h1Mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop headers that are not under test (Date changes on every
			// run) so the dumped response can be compared exactly.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// The superfluous-call entries come from the handler's WriteHeader
			// calls. In the timeout case the response may arrive before the
			// handler has made them. Receiving from lastLine first guarantees
			// they have happened, and avoids reading logBuf while it may
			// still be written to.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6486
6487
6488
6489
6490 func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
6491 conn, err := net.Dial("tcp", host)
6492 if err != nil {
6493 return nil, err
6494 }
6495 defer conn.Close()
6496
6497 if _, err := conn.Write(http1ReqBody); err != nil {
6498 return nil, err
6499 }
6500 return io.ReadAll(conn)
6501 }
6502
// BenchmarkResponseStatusLine measures Export_writeStatusLine for status 200
// from parallel goroutines and reports allocations. Each goroutine uses its
// own bufio.Writer over io.Discard and its own 3-byte scratch buffer.
func BenchmarkResponseStatusLine(b *testing.B) {
	b.ReportAllocs()
	b.RunParallel(func(pb *testing.PB) {
		var scratch [3]byte
		w := bufio.NewWriter(io.Discard)
		for pb.Next() {
			Export_writeStatusLine(w, true, 200, scratch[:])
		}
	})
}
// TestDisableKeepAliveUpgrade checks that a 101 Switching Protocols upgrade
// still works with keep-alives disabled on both server and client. The
// client's Response.Body must be an io.ReadWriteCloser, and bytes written to
// it must come back from the hijacked, echoing server.
func TestDisableKeepAliveUpgrade(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	setParallel(t)
	defer afterTest(t)

	s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		h := w.Header()
		h.Set("Connection", "Upgrade")
		h.Set("Upgrade", "someProto")
		w.WriteHeader(StatusSwitchingProtocols)
		conn, brw, err := w.(Hijacker).Hijack()
		if err != nil {
			return
		}
		defer conn.Close()

		// Echo: read through the bufio.ReadWriter, which may already hold
		// buffered client bytes. Write straight to the net.Conn so nothing
		// is left unflushed.
		io.Copy(conn, brw)
	}))
	s.Config.SetKeepAlivesEnabled(false)
	s.Start()
	defer s.Close()

	client := s.Client()
	client.Transport.(*Transport).DisableKeepAlives = true

	resp, err := client.Get(s.URL)
	if err != nil {
		t.Fatalf("failed to perform request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != StatusSwitchingProtocols {
		t.Fatalf("unexpected status code: %v", resp.StatusCode)
	}

	rwc, ok := resp.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
	}

	if _, err := rwc.Write([]byte("hello")); err != nil {
		t.Fatalf("failed to write to body: %v", err)
	}

	got := make([]byte, 5)
	if _, err := io.ReadFull(rwc, got); err != nil {
		t.Fatalf("failed to read from body: %v", err)
	}

	if string(got) != "hello" {
		t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", got, "hello")
	}
}
6572
// TestMuxRedirectRelative sends a request with the absolute-form target
// "http://example.com" (no path) through an empty ServeMux. It checks that
// the mux answers 301 Moved Permanently with the relative Location "/".
func TestMuxRedirectRelative(t *testing.T) {
	setParallel(t)
	req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
	if err != nil {
		// Without a parsed request there is nothing meaningful to serve;
		// stop instead of handing an unusable *Request to ServeHTTP.
		t.Fatalf("%s", err)
	}
	mux := NewServeMux()
	resp := httptest.NewRecorder()
	mux.ServeHTTP(resp, req)
	if got, want := resp.Header().Get("Location"), "/"; got != want {
		t.Errorf("Location header expected %q; got %q", want, got)
	}
	if got, want := resp.Code, StatusMovedPermanently; got != want {
		t.Errorf("Expected response code %d; got %d", want, got)
	}
}
6589
6590
// TestQuerySemicolon runs testQuerySemicolon on each query twice.
//   - allow=false: default handling. The handler should see xNoSemicolons,
//     and a warning is expected when the case's warning flag is set.
//   - allow=true: the handler is wrapped in AllowQuerySemicolons. It should
//     see xWithSemicolons, and no warning is ever expected.
func TestQuerySemicolon(t *testing.T) {
	// Cleanup rather than defer: it runs only after every subtest,
	// including parallel ones, has completed.
	t.Cleanup(func() { afterTest(t) })

	cases := []struct {
		query           string
		xNoSemicolons   string
		xWithSemicolons string
		warning         bool
	}{
		{"?a=1;x=bad&x=good", "good", "bad", true},
		{"?a=1;b=bad&x=good", "good", "good", true},
		{"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
		{"?a=1;x=good;x=bad", "", "good", true},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.query+"/allow=false", func(t *testing.T) {
			testQuerySemicolon(t, tc.query, tc.xNoSemicolons, false, tc.warning)
		})
		t.Run(tc.query+"/allow=true", func(t *testing.T) {
			testQuerySemicolon(t, tc.query, tc.xWithSemicolons, true, false)
		})
	}
}
6617
6618 func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolons, expectWarning bool) {
6619 setParallel(t)
6620
6621 writeBackX := func(w ResponseWriter, r *Request) {
6622 x := r.URL.Query().Get("x")
6623 if expectWarning {
6624 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6625 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6626 }
6627 } else {
6628 if err := r.ParseForm(); err != nil {
6629 t.Errorf("expected no error from ParseForm, got %v", err)
6630 }
6631 }
6632 if got := r.FormValue("x"); x != got {
6633 t.Errorf("got %q from FormValue, want %q", got, x)
6634 }
6635 fmt.Fprintf(w, "%s", x)
6636 }
6637
6638 h := Handler(HandlerFunc(writeBackX))
6639 if allowSemicolons {
6640 h = AllowQuerySemicolons(h)
6641 }
6642
6643 ts := httptest.NewUnstartedServer(h)
6644 logBuf := &bytes.Buffer{}
6645 ts.Config.ErrorLog = log.New(logBuf, "", 0)
6646 ts.Start()
6647 defer ts.Close()
6648
6649 req, _ := NewRequest("GET", ts.URL+query, nil)
6650 res, err := ts.Client().Do(req)
6651 if err != nil {
6652 t.Fatal(err)
6653 }
6654 slurp, _ := io.ReadAll(res.Body)
6655 res.Body.Close()
6656 if got, want := res.StatusCode, 200; got != want {
6657 t.Errorf("Status = %d; want = %d", got, want)
6658 }
6659 if got, want := string(slurp), wantX; got != want {
6660 t.Errorf("Body = %q; want = %q", got, want)
6661 }
6662
6663 if expectWarning {
6664 if !strings.Contains(logBuf.String(), "semicolon") {
6665 t.Errorf("got %q from ErrorLog, expected a mention of semicolons", logBuf.String())
6666 }
6667 } else {
6668 if strings.Contains(logBuf.String(), "semicolon") {
6669 t.Errorf("got %q from ErrorLog, expected no mention of semicolons", logBuf.String())
6670 }
6671 }
6672 }
6673
// TestMaxBytesHandler runs testMaxBytesHandler for every pairing of handler
// limit and request body size drawn from the same three sizes. This covers
// bodies below, equal to, and above the limit.
func TestMaxBytesHandler(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	sizes := []int64{100, 1_000, 1_000_000}
	for _, maxSize := range sizes {
		for _, requestSize := range sizes {
			name := fmt.Sprintf("max size %d request size %d", maxSize, requestSize)
			maxSize, requestSize := maxSize, requestSize
			t.Run(name, func(t *testing.T) {
				testMaxBytesHandler(t, maxSize, requestSize)
			})
		}
	}
}
6687
// testMaxBytesHandler POSTs requestSize bytes to an echo handler wrapped in
// MaxBytesHandler with limit maxSize. It checks that:
//   - the handler never reads more than maxSize bytes;
//   - the handler's read fails exactly when the body exceeds the limit;
//   - a body within the limit is read in full;
//   - the client gets back exactly as many bytes as the handler read.
func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
	// Written by the handler and inspected after the round trip.
	var handlerN int64
	var handlerErr error

	echo := HandlerFunc(func(w ResponseWriter, r *Request) {
		// Buffer the whole (limited) body, then send it back.
		var received bytes.Buffer
		handlerN, handlerErr = io.Copy(&received, r.Body)
		io.Copy(w, &received)
	})

	ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
	defer ts.Close()

	var echoed strings.Builder
	payload := strings.NewReader(strings.Repeat("a", int(requestSize)))
	if res, err := ts.Client().Post(ts.URL, "text/plain", payload); err != nil {
		t.Errorf("unexpected connection error: %v", err)
	} else {
		_, err = io.Copy(&echoed, res.Body)
		res.Body.Close()
		if err != nil {
			t.Errorf("unexpected read error: %v", err)
		}
	}

	if handlerN > maxSize {
		t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
	}
	if requestSize > maxSize {
		if handlerErr == nil {
			t.Error("expected error on handler side; got nil")
		}
	} else {
		if handlerErr != nil {
			t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
		}
		if handlerN != requestSize {
			t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
		}
	}
	if echoed.Len() != int(handlerN) {
		t.Errorf("expected echo of size %d; got %d", handlerN, echoed.Len())
	}
}
6733
View as plain text