Source file
src/net/http/clientserver_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "crypto/rand"
13 "crypto/sha1"
14 "crypto/tls"
15 "fmt"
16 "hash"
17 "io"
18 "log"
19 "net"
20 . "net/http"
21 "net/http/httptest"
22 "net/http/httputil"
23 "net/url"
24 "os"
25 "reflect"
26 "runtime"
27 "sort"
28 "strings"
29 "sync"
30 "sync/atomic"
31 "testing"
32 "time"
33 )
34
35 type clientServerTest struct {
36 t *testing.T
37 h2 bool
38 h Handler
39 ts *httptest.Server
40 tr *Transport
41 c *Client
42 }
43
44 func (t *clientServerTest) close() {
45 t.tr.CloseIdleConnections()
46 t.ts.Close()
47 }
48
// getURL GETs u with the fixture's client and returns the complete response
// body as a string. A request or read failure aborts the test.
func (t *clientServerTest) getURL(u string) string {
	res, err := t.c.Get(u)
	if err != nil {
		t.t.Fatal(err)
	}
	defer res.Body.Close()

	body, readErr := io.ReadAll(res.Body)
	if readErr != nil {
		t.t.Fatal(readErr)
	}
	return string(body)
}
61
62 func (t *clientServerTest) scheme() string {
63 if t.h2 {
64 return "https"
65 }
66 return "http"
67 }
68
// h1Mode and h2Mode are the values for newClientServerTest's h2 argument, so
// call sites name the protocol instead of passing a bare boolean.
const (
	h1Mode = false   // plain HTTP/1 server
	h2Mode = !h1Mode // TLS server configured for HTTP/2
)
73
74 var optQuietLog = func(ts *httptest.Server) {
75 ts.Config.ErrorLog = quietLog
76 }
77
// optWithServerLog returns a newClientServerTest option that makes lg the
// test server's ErrorLog.
func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
	setLogger := func(srv *httptest.Server) {
		srv.Config.ErrorLog = lg
	}
	return setLogger
}
83
// newClientServerTest starts a test server running h and returns a fixture
// whose client is wired to reach it.
//
// When h2 is true, the test is skipped via CondSkipHTTP2 if HTTP/2 is not
// available. Otherwise the server is configured for HTTP/2 and started with
// TLS, and the client transport is set to skip certificate verification and
// is configured for HTTP/2. When h2 is false, a plain server is started.
//
// Each element of opts must be a func(*Transport), applied to the client
// transport, or a func(*httptest.Server), applied to the still-unstarted
// server. Any other type fails the test. Options run before the server
// starts.
func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest {
	// Attribute setup failures (bad option, transport configuration) to the
	// calling test rather than to this helper.
	t.Helper()
	if h2 {
		CondSkipHTTP2(t)
	}
	cst := &clientServerTest{
		t:  t,
		h2: h2,
		h:  h,
		tr: &Transport{},
	}
	cst.c = &Client{Transport: cst.tr}
	cst.ts = httptest.NewUnstartedServer(h)

	for _, opt := range opts {
		switch opt := opt.(type) {
		case func(*Transport):
			opt(cst.tr)
		case func(*httptest.Server):
			opt(cst.ts)
		default:
			t.Fatalf("unhandled option type %T", opt)
		}
	}

	if !h2 {
		cst.ts.Start()
		return cst
	}
	ExportHttp2ConfigureServer(cst.ts.Config, nil)
	// Serve with the TLS config that the HTTP/2 configuration produced.
	cst.ts.TLS = cst.ts.Config.TLSConfig
	cst.ts.StartTLS()

	cst.tr.TLSClientConfig = &tls.Config{
		InsecureSkipVerify: true,
	}
	if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
		t.Fatal(err)
	}
	return cst
}
124
125
// TestNewClientServerTest checks that fixtures built in h1Mode and h2Mode
// serve requests as HTTP/1.1 and HTTP/2.0 respectively, as seen by the
// handler through r.Proto.
func TestNewClientServerTest(t *testing.T) {
	var got struct {
		sync.Mutex
		log []string
	}
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		got.Lock()
		defer got.Unlock()
		got.log = append(got.log, r.Proto)
	})
	for _, v := range [2]bool{false, true} {
		// A closure per iteration lets the deferred close run even when Head
		// fails and t.Fatal unwinds the goroutine.
		func() {
			cst := newClientServerTest(t, v, h)
			defer cst.close()
			res, err := cst.c.Head(cst.ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			res.Body.Close()
		}()
	}
	got.Lock()
	defer got.Unlock()
	if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) {
		t.Errorf("got %q; want %q", got.log, want)
	}
}
148
// TestChunkedResponseHeaders_h1 runs testChunkedResponseHeaders over HTTP/1.
func TestChunkedResponseHeaders_h1(t *testing.T) {
	testChunkedResponseHeaders(t, h1Mode)
}

// TestChunkedResponseHeaders_h2 runs testChunkedResponseHeaders over HTTP/2.
func TestChunkedResponseHeaders_h2(t *testing.T) {
	testChunkedResponseHeaders(t, h2Mode)
}
151
152 func testChunkedResponseHeaders(t *testing.T, h2 bool) {
153 defer afterTest(t)
154 log.SetOutput(io.Discard)
155 defer log.SetOutput(os.Stderr)
156 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
157 w.Header().Set("Content-Length", "intentional gibberish")
158 w.(Flusher).Flush()
159 fmt.Fprintf(w, "I am a chunked response.")
160 }))
161 defer cst.close()
162
163 res, err := cst.c.Get(cst.ts.URL)
164 if err != nil {
165 t.Fatalf("Get error: %v", err)
166 }
167 defer res.Body.Close()
168 if g, e := res.ContentLength, int64(-1); g != e {
169 t.Errorf("expected ContentLength of %d; got %d", e, g)
170 }
171 wantTE := []string{"chunked"}
172 if h2 {
173 wantTE = nil
174 }
175 if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
176 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
177 }
178 if got, haveCL := res.Header["Content-Length"]; haveCL {
179 t.Errorf("Unexpected Content-Length: %q", got)
180 }
181 }
182
// reqFunc issues one request to target using client; method expressions
// such as (*Client).Get and (*Client).Head have this type.
type reqFunc func(client *Client, target string) (*Response, error)
184
185
186
187 type h12Compare struct {
188 Handler func(ResponseWriter, *Request)
189 ReqFunc reqFunc
190 CheckResponse func(proto string, res *Response)
191 EarlyCheckResponse func(proto string, res *Response)
192 Opts []any
193 }
194
// reqFunc returns tt.ReqFunc when one was supplied and (*Client).Get
// otherwise.
func (tt h12Compare) reqFunc() reqFunc {
	if fn := tt.ReqFunc; fn != nil {
		return fn
	}
	return (*Client).Get
}
201
// run serves tt.Handler from an HTTP/1 fixture and from an HTTP/2 fixture
// (both built with tt.Opts), sends the same request to each, and reports an
// error if the normalized responses or their bodies differ.
//
// EarlyCheckResponse, if set, sees each raw response. CheckResponse, if set,
// sees each normalized response, whose Body is a slurpResult. A request
// error on either protocol is reported and ends the comparison.
func (tt h12Compare) run(t *testing.T) {
	setParallel(t)
	cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
	defer cst1.close()
	cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
	defer cst2.close()

	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
	if err != nil {
		t.Errorf("HTTP/1 request: %v", err)
		return
	}
	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
	if err != nil {
		t.Errorf("HTTP/2 request: %v", err)
		// res1 will not go through normalizeRes (which reads and closes the
		// body), so close it here rather than leak it.
		res1.Body.Close()
		return
	}

	if fn := tt.EarlyCheckResponse; fn != nil {
		fn("HTTP/1.1", res1)
		fn("HTTP/2.0", res2)
	}

	tt.normalizeRes(t, res1, "HTTP/1.1")
	tt.normalizeRes(t, res2, "HTTP/2.0")
	res1body, res2body := res1.Body, res2.Body

	eres1 := mostlyCopy(res1)
	eres2 := mostlyCopy(res2)
	if !reflect.DeepEqual(eres1, eres2) {
		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
	}
	if !reflect.DeepEqual(res1body, res2body) {
		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
	}
	if fn := tt.CheckResponse; fn != nil {
		res1.Body, res2.Body = res1body, res2body
		fn("HTTP/1.1", res1)
		fn("HTTP/2.0", res2)
	}
}
244
// mostlyCopy returns a shallow copy of r with Body, TransferEncoding, TLS and
// Request cleared, so the remaining fields can be compared directly with
// reflect.DeepEqual.
func mostlyCopy(r *Response) *Response {
	dup := new(Response)
	*dup = *r
	dup.Body, dup.TransferEncoding = nil, nil
	dup.TLS, dup.Request = nil, nil
	return dup
}
253
254 type slurpResult struct {
255 io.ReadCloser
256 body []byte
257 err error
258 }
259
260 func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
261
// normalizeRes prepares res for a cross-protocol comparison:
//   - res.Proto must equal wantProto (or the wildcard "HTTP/IGNORE"); the
//     protocol fields are then zeroed, otherwise an error is reported;
//   - the body is read fully, closed, and replaced with a slurpResult
//     holding the bytes and the read error;
//   - every Date header value is masked with the same number of 'x's;
//   - res.Request must be set, and TLS must be present only for HTTP/2.0.
func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
	switch res.Proto {
	case wantProto, "HTTP/IGNORE":
		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
	default:
		t.Errorf("got %q response; want %q", res.Proto, wantProto)
	}

	data, readErr := io.ReadAll(res.Body)
	res.Body.Close()
	res.Body = slurpResult{
		ReadCloser: io.NopCloser(bytes.NewReader(data)),
		body:       data,
		err:        readErr,
	}

	// Date values depend on the time; keep only their length.
	dates := res.Header["Date"]
	for i := range dates {
		dates[i] = strings.Repeat("x", len(dates[i]))
	}

	if res.Request == nil {
		t.Errorf("for %s, no request", wantProto)
	}
	wantTLS := wantProto == "HTTP/2.0"
	if gotTLS := res.TLS != nil; gotTLS != wantTLS {
		t.Errorf("TLS set = %v; want %v", gotTLS, wantTLS)
	}
}
286
287
// TestH12_HeadContentLengthNoBody compares HEAD responses from a handler
// that writes nothing.
func TestH12_HeadContentLengthNoBody(t *testing.T) {
	tc := h12Compare{
		ReqFunc: (*Client).Head,
		Handler: func(ResponseWriter, *Request) {},
	}
	tc.run(t)
}
295
// TestH12_HeadContentLengthSmallBody compares HEAD responses from a handler
// that writes a short body.
func TestH12_HeadContentLengthSmallBody(t *testing.T) {
	tc := h12Compare{
		ReqFunc: (*Client).Head,
		Handler: func(w ResponseWriter, _ *Request) {
			io.WriteString(w, "small")
		},
	}
	tc.run(t)
}
304
305 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
306 h12Compare{
307 ReqFunc: (*Client).Head,
308 Handler: func(w ResponseWriter, r *Request) {
309 chunk := strings.Repeat("x", 512<<10)
310 for i := 0; i < 10; i++ {
311 io.WriteString(w, chunk)
312 }
313 },
314 }.run(t)
315 }
316
// TestH12_200NoBody compares GET responses from a handler that returns
// without writing a status or body.
func TestH12_200NoBody(t *testing.T) {
	noop := func(ResponseWriter, *Request) {}
	h12Compare{Handler: noop}.run(t)
}
320
// TestH2_204NoBody compares responses that carry only status 204.
func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, StatusNoContent) }

// TestH2_304NoBody compares responses that carry only status 304.
func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, StatusNotModified) }

// TestH2_404NoBody compares responses that carry only status 404.
func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, StatusNotFound) }
324
// testH12_noBody compares HTTP/1 and HTTP/2 responses from a handler that
// only writes the given status code.
func testH12_noBody(t *testing.T, status int) {
	writeStatus := func(w ResponseWriter, _ *Request) {
		w.WriteHeader(status)
	}
	h12Compare{Handler: writeStatus}.run(t)
}
330
// TestH12_SmallBody compares responses from a handler writing a short body.
func TestH12_SmallBody(t *testing.T) {
	writeSmall := func(w ResponseWriter, _ *Request) {
		io.WriteString(w, "small body")
	}
	h12Compare{Handler: writeSmall}.run(t)
}
336
// TestH12_ExplicitContentLength compares responses whose handler declares a
// Content-Length that matches the body it writes.
func TestH12_ExplicitContentLength(t *testing.T) {
	handler := func(w ResponseWriter, _ *Request) {
		w.Header().Set("Content-Length", "3")
		io.WriteString(w, "foo")
	}
	h12Compare{Handler: handler}.run(t)
}
343
// TestH12_FlushBeforeBody compares responses whose handler flushes before
// writing any body bytes.
func TestH12_FlushBeforeBody(t *testing.T) {
	handler := func(w ResponseWriter, _ *Request) {
		w.(Flusher).Flush()
		io.WriteString(w, "foo")
	}
	h12Compare{Handler: handler}.run(t)
}
350
// TestH12_FlushMidBody compares responses whose handler flushes between two
// body writes.
func TestH12_FlushMidBody(t *testing.T) {
	handler := func(w ResponseWriter, _ *Request) {
		io.WriteString(w, "foo")
		w.(Flusher).Flush()
		io.WriteString(w, "bar")
	}
	h12Compare{Handler: handler}.run(t)
}
358
// TestH12_Head_ExplicitLen compares HEAD responses whose handler sets a
// Content-Length but writes no body.
func TestH12_Head_ExplicitLen(t *testing.T) {
	handler := func(w ResponseWriter, r *Request) {
		if r.Method != MethodHead {
			t.Errorf("unexpected method %q", r.Method)
		}
		w.Header().Set("Content-Length", "1235")
	}
	tc := h12Compare{ReqFunc: (*Client).Head, Handler: handler}
	tc.run(t)
}
370
// TestH12_Head_ImplicitLen compares HEAD responses whose handler writes a
// body without declaring its length.
func TestH12_Head_ImplicitLen(t *testing.T) {
	handler := func(w ResponseWriter, r *Request) {
		if r.Method != MethodHead {
			t.Errorf("unexpected method %q", r.Method)
		}
		io.WriteString(w, "foo")
	}
	tc := h12Compare{ReqFunc: (*Client).Head, Handler: handler}
	tc.run(t)
}
382
// TestH12_HandlerWritesTooLittle: a handler that declares Content-Length 3
// but writes only two bytes must leave both clients with the body "12" and
// a read error of io.ErrUnexpectedEOF.
func TestH12_HandlerWritesTooLittle(t *testing.T) {
	handler := func(w ResponseWriter, _ *Request) {
		w.Header().Set("Content-Length", "3")
		io.WriteString(w, "12")
	}
	check := func(proto string, res *Response) {
		sr, isSlurp := res.Body.(slurpResult)
		if !isSlurp {
			t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
			return
		}
		if sr.err != io.ErrUnexpectedEOF {
			t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
		}
		if got := string(sr.body); got != "12" {
			t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
		}
	}
	h12Compare{Handler: handler, CheckResponse: check}.run(t)
}
404
405
406
407
408
409
410
// TestH12_HandlerWritesTooMuch: once a handler has written the full declared
// Content-Length, any further write must fail without writing anything.
func TestH12_HandlerWritesTooMuch(t *testing.T) {
	handler := func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", "3")
		flusher := w.(Flusher)
		flusher.Flush()
		io.WriteString(w, "123")
		flusher.Flush()
		if n, err := io.WriteString(w, "x"); n > 0 || err == nil {
			t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
		}
	}
	h12Compare{Handler: handler}.run(t)
}
425
426
427
// TestH12_AutoGzip: with default transports both clients must advertise
// "Accept-Encoding: gzip", and the gzip-encoded responses are compared.
func TestH12_AutoGzip(t *testing.T) {
	handler := func(w ResponseWriter, r *Request) {
		if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
			t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
		}
		w.Header().Set("Content-Encoding", "gzip")
		zw := gzip.NewWriter(w)
		io.WriteString(zw, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
		zw.Close()
	}
	h12Compare{Handler: handler}.run(t)
}
441
// TestH12_AutoGzip_Disabled: with DisableCompression set on the transports,
// neither client may send an Accept-Encoding header.
func TestH12_AutoGzip_Disabled(t *testing.T) {
	disableCompression := func(tr *Transport) { tr.DisableCompression = true }
	tc := h12Compare{
		Opts: []any{disableCompression},
		Handler: func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
			}
		},
	}
	tc.run(t)
}
455
456
457
458
// Test304Responses_h1 runs test304Responses over HTTP/1.
func Test304Responses_h1(t *testing.T) {
	test304Responses(t, h1Mode)
}

// Test304Responses_h2 runs test304Responses over HTTP/2.
func Test304Responses_h2(t *testing.T) {
	test304Responses(t, h2Mode)
}
461
// test304Responses checks that a 304 Not Modified response has no body: a
// handler's Write after WriteHeader(304) fails with ErrBodyNotAllowed, and
// the client sees no TransferEncoding and an empty body.
func test304Responses(t *testing.T, h2 bool) {
	defer afterTest(t)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(StatusNotModified)
		_, err := w.Write([]byte("illegal body"))
		if err != ErrBodyNotAllowed {
			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
		}
	}))
	defer cst.close()
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// The body was previously read but never closed; close it before the
	// fixture and afterTest checks run.
	defer res.Body.Close()
	if len(res.TransferEncoding) > 0 {
		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Error(err)
	}
	if len(body) > 0 {
		t.Errorf("got unexpected body %q", string(body))
	}
}
487
// TestH12_ServerEmptyContentLength compares responses whose handler sets an
// explicitly empty Content-Type value before writing an HTML body.
func TestH12_ServerEmptyContentLength(t *testing.T) {
	handler := func(w ResponseWriter, _ *Request) {
		hdr := w.Header()
		hdr["Content-Type"] = []string{""}
		io.WriteString(w, "<html><body>hi</body></html>")
	}
	h12Compare{Handler: handler}.run(t)
}
496
// TestH12_RequestContentLength_Known_NonZero: a *strings.Reader body of four
// bytes must reach the handler with ContentLength 4.
func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
	four := func() io.Reader { return strings.NewReader("FOUR") }
	h12requestContentLength(t, four, 4)
}

// TestH12_RequestContentLength_Known_Zero: a nil body must reach the handler
// with ContentLength 0.
func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
	none := func() io.Reader { return nil }
	h12requestContentLength(t, none, 0)
}

// TestH12_RequestContentLength_Unknown: wrapping the reader in an anonymous
// struct hides its concrete type, so the handler must see ContentLength -1.
func TestH12_RequestContentLength_Unknown(t *testing.T) {
	opaque := func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }
	h12requestContentLength(t, opaque, -1)
}
508
// h12requestContentLength POSTs the body returned by bodyfn over HTTP/1 and
// HTTP/2. The handler echoes the request's ContentLength in the Got-Length
// header, which must equal wantLen for both protocols.
func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
	want := fmt.Sprint(wantLen)
	echoLength := func(w ResponseWriter, r *Request) {
		w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
		fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
	}
	post := func(c *Client, url string) (*Response, error) {
		return c.Post(url, "text/plain", bodyfn())
	}
	check := func(proto string, res *Response) {
		if got := res.Header.Get("Got-Length"); got != want {
			t.Errorf("Proto %q got length %q; want %q", proto, got, want)
		}
	}
	h12Compare{Handler: echoLength, ReqFunc: post, CheckResponse: check}.run(t)
}
525
526
527
// TestCancelRequestMidBody_h1 runs testCancelRequestMidBody over HTTP/1.
func TestCancelRequestMidBody_h1(t *testing.T) {
	testCancelRequestMidBody(t, h1Mode)
}

// TestCancelRequestMidBody_h2 runs testCancelRequestMidBody over HTTP/2.
func TestCancelRequestMidBody_h2(t *testing.T) {
	testCancelRequestMidBody(t, h2Mode)
}
// testCancelRequestMidBody cancels a request through Request.Cancel after
// part of the body has arrived. The client must end up with only the bytes
// sent before cancellation ("Hello"), and the final read must fail with
// ExportErrRequestCanceled.
func testCancelRequestMidBody(t *testing.T, h2 bool) {
	defer afterTest(t)
	unblock := make(chan bool)
	didFlush := make(chan bool, 1)
	handler := func(w ResponseWriter, r *Request) {
		io.WriteString(w, "Hello")
		w.(Flusher).Flush()
		didFlush <- true
		// Hold back the rest of the body until the test function returns.
		<-unblock
		io.WriteString(w, ", world.")
	}
	cst := newClientServerTest(t, h2, HandlerFunc(handler))
	defer cst.close()
	defer close(unblock)

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	// Don't read until the handler has flushed its first chunk.
	<-didFlush

	buf := make([]byte, 10)
	n, err := res.Body.Read(buf)
	if err != nil {
		t.Fatal(err)
	}
	head := buf[:n]

	close(cancel)

	tail, err := io.ReadAll(res.Body)
	if got := string(head) + string(tail); got != "Hello" {
		t.Errorf("Read %q (%q + %q); want Hello", got, head, tail)
	}
	if err != ExportErrRequestCanceled {
		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
	}
}
575
576
// TestTrailersClientToServer_h1 runs testTrailersClientToServer over HTTP/1.
func TestTrailersClientToServer_h1(t *testing.T) {
	testTrailersClientToServer(t, h1Mode)
}

// TestTrailersClientToServer_h2 runs testTrailersClientToServer over HTTP/2.
func TestTrailersClientToServer_h2(t *testing.T) {
	testTrailersClientToServer(t, h2Mode)
}
579
// testTrailersClientToServer sends a POST of unknown length whose trailer
// keys are declared up front but whose values are filled in while the body
// is being read. The handler must see both declared keys and, after reading
// the body "foo", both final values.
func testTrailersClientToServer(t *testing.T, h2 bool) {
	defer afterTest(t)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		var declared []string
		for key := range r.Trailer {
			declared = append(declared, key)
		}
		sort.Strings(declared)

		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("Server reading request body: %v", err)
		}
		if string(body) != "foo" {
			t.Errorf("Server read request body %q; want foo", body)
		}
		if r.Trailer == nil {
			io.WriteString(w, "nil Trailer")
			return
		}
		fmt.Fprintf(w, "decl: %v, vals: %s, %s",
			declared,
			r.Trailer.Get("Client-Trailer-A"),
			r.Trailer.Get("Client-Trailer-B"))
	}))
	defer cst.close()

	// req is declared before NewRequest because the body hooks refer to it.
	// Presumably eofReaderFunc runs its func when the read reaches it and then
	// reports EOF, so each value is set only as the body streams past it.
	var req *Request
	setTrailerOnRead := func(key, val string) io.Reader {
		return eofReaderFunc(func() {
			req.Trailer[key] = []string{val}
		})
	}
	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
		setTrailerOnRead("Client-Trailer-A", "valuea"),
		strings.NewReader("foo"),
		setTrailerOnRead("Client-Trailer-B", "valueb"),
	))
	// Declare the keys now; their values arrive during the body read.
	req.Trailer = Header{
		"Client-Trailer-A": nil,
		"Client-Trailer-B": nil,
	}
	req.ContentLength = -1

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	const want = "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"
	if err := wantBody(res, err, want); err != nil {
		t.Error(err)
	}
}
630
631
// TestTrailersServerToClient_h1 checks server trailers over HTTP/1.
func TestTrailersServerToClient_h1(t *testing.T) {
	testTrailersServerToClient(t, h1Mode, false)
}

// TestTrailersServerToClient_h2 checks server trailers over HTTP/2.
func TestTrailersServerToClient_h2(t *testing.T) {
	testTrailersServerToClient(t, h2Mode, false)
}

// TestTrailersServerToClient_Flush_h1 is the HTTP/1 variant where the handler
// flushes after the body.
func TestTrailersServerToClient_Flush_h1(t *testing.T) {
	testTrailersServerToClient(t, h1Mode, true)
}

// TestTrailersServerToClient_Flush_h2 is the HTTP/2 variant where the handler
// flushes after the body.
func TestTrailersServerToClient_Flush_h2(t *testing.T) {
	testTrailersServerToClient(t, h2Mode, true)
}
636
// testTrailersServerToClient has the handler declare three trailer keys,
// write a body (optionally flushing), and then set values for two declared
// keys and one undeclared key. The client must see the declared keys, with
// nil values, before reading the body, and afterwards only the declared
// values; the undeclared one is dropped.
func testTrailersServerToClient(t *testing.T, h2, flush bool) {
	defer afterTest(t)
	const body = "Some body"
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
		w.Header().Add("Trailer", "Server-Trailer-C")

		io.WriteString(w, body)
		if flush {
			w.(Flusher).Flush()
		}

		// Values set after the body is written are expected to arrive as
		// trailers. w.Header() is fetched anew for each call, as in the
		// original, so the ResponseWriter observes post-write header access.
		w.Header().Set("Server-Trailer-A", "valuea")
		w.Header().Set("Server-Trailer-C", "valuec")
		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
	}))
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	wantLen := -1
	wantHeader := Header{"Content-Type": {"text/plain; charset=utf-8"}}
	if h2 && !flush {
		// The test expects HTTP/2 to report a known length when the handler
		// never flushed; every other case has unknown length.
		wantLen = len(body)
		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
	}
	if res.ContentLength != int64(wantLen) {
		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
	}

	delete(res.Header, "Date") // time-dependent
	if !reflect.DeepEqual(res.Header, wantHeader) {
		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
	}

	wantBefore := Header{
		"Server-Trailer-A": nil,
		"Server-Trailer-B": nil,
		"Server-Trailer-C": nil,
	}
	if !reflect.DeepEqual(res.Trailer, wantBefore) {
		t.Errorf("Trailer before body read = %v; want %v", res.Trailer, wantBefore)
	}

	if err := wantBody(res, nil, body); err != nil {
		t.Fatal(err)
	}

	wantAfter := Header{
		"Server-Trailer-A": {"valuea"},
		"Server-Trailer-B": nil,
		"Server-Trailer-C": {"valuec"},
	}
	if !reflect.DeepEqual(res.Trailer, wantAfter) {
		t.Errorf("Trailer after body read = %v; want %v", res.Trailer, wantAfter)
	}
}
706
707
// TestResponseBodyReadAfterClose_h1 runs testResponseBodyReadAfterClose over HTTP/1.
func TestResponseBodyReadAfterClose_h1(t *testing.T) {
	testResponseBodyReadAfterClose(t, h1Mode)
}

// TestResponseBodyReadAfterClose_h2 runs testResponseBodyReadAfterClose over HTTP/2.
func TestResponseBodyReadAfterClose_h2(t *testing.T) {
	testResponseBodyReadAfterClose(t, h2Mode)
}
710
// testResponseBodyReadAfterClose checks that reading a response body after
// closing it yields no data and a non-nil error.
func testResponseBodyReadAfterClose(t *testing.T, h2 bool) {
	defer afterTest(t)
	const body = "Some body"
	writeBody := func(w ResponseWriter, _ *Request) {
		io.WriteString(w, body)
	}
	cst := newClientServerTest(t, h2, HandlerFunc(writeBody))
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if data, err := io.ReadAll(res.Body); len(data) > 0 || err == nil {
		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
	}
}
728
// TestConcurrentReadWriteReqBody_h1 runs testConcurrentReadWriteReqBody over HTTP/1.
func TestConcurrentReadWriteReqBody_h1(t *testing.T) {
	testConcurrentReadWriteReqBody(t, h1Mode)
}

// TestConcurrentReadWriteReqBody_h2 runs testConcurrentReadWriteReqBody over HTTP/2.
func TestConcurrentReadWriteReqBody_h2(t *testing.T) {
	testConcurrentReadWriteReqBody(t, h2Mode)
}
// testConcurrentReadWriteReqBody sends a request with "Expect: 100-continue".
// The handler reads the request body in one goroutine and writes the
// response in another. Over HTTP/2 the two run concurrently; over HTTP/1 the
// writer first waits for the read to finish. Both sides must see the
// expected bodies.
func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) {
	defer afterTest(t)
	const reqBody = "some request body"
	const resBody = "some response body"
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		var wg sync.WaitGroup
		wg.Add(2)
		didRead := make(chan bool, 1)

		// Reader: consume the whole request body.
		go func() {
			defer wg.Done()
			got, err := io.ReadAll(r.Body)
			if string(got) != reqBody {
				t.Errorf("Handler read %q; want %q", got, reqBody)
			}
			if err != nil {
				t.Errorf("Handler Read: %v", err)
			}
			didRead <- true
		}()

		// Writer: send the response.
		go func() {
			defer wg.Done()
			if !h2 {
				// HTTP/1 is not exercised with a truly concurrent write;
				// wait until the request body has been read.
				<-didRead
			}
			io.WriteString(w, resBody)
		}()

		wg.Wait()
	}))
	defer cst.close()

	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
	req.Header.Add("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	got, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if string(got) != resBody {
		t.Errorf("read %q; want %q", got, resBody)
	}
}
781
// TestConnectRequest_h1 runs testConnectRequest over HTTP/1.
func TestConnectRequest_h1(t *testing.T) {
	testConnectRequest(t, h1Mode)
}

// TestConnectRequest_h2 runs testConnectRequest over HTTP/2.
func TestConnectRequest_h2(t *testing.T) {
	testConnectRequest(t, h2Mode)
}
// testConnectRequest sends CONNECT requests and checks that the handler sees
// method CONNECT, with both Host and URL.Host set to the request's Host when
// one was given, or to the URL's host otherwise.
func testConnectRequest(t *testing.T, h2 bool) {
	defer afterTest(t)
	gotc := make(chan *Request, 1)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotc <- r
	}))
	defer cst.close()

	u, err := url.Parse(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	newConnect := func(host string) *Request {
		return &Request{Method: "CONNECT", Header: Header{}, URL: u, Host: host}
	}
	cases := []struct {
		req  *Request
		want string
	}{
		{req: newConnect(""), want: u.Host},
		{req: newConnect("example.com:123"), want: "example.com:123"},
	}

	for i, c := range cases {
		res, err := cst.c.Do(c.req)
		if err != nil {
			t.Errorf("%d. RoundTrip = %v", i, err)
			continue
		}
		res.Body.Close()

		got := <-gotc
		if got.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", got.Method)
		}
		if got.Host != c.want {
			t.Errorf("Host = %q; want %q", got.Host, c.want)
		}
		if got.URL.Host != c.want {
			t.Errorf("URL.Host = %q; want %q", got.URL.Host, c.want)
		}
	}
}
839
// TestTransportUserAgent_h1 runs testTransportUserAgent over HTTP/1.
func TestTransportUserAgent_h1(t *testing.T) {
	testTransportUserAgent(t, h1Mode)
}

// TestTransportUserAgent_h2 runs testTransportUserAgent over HTTP/2.
func TestTransportUserAgent_h2(t *testing.T) {
	testTransportUserAgent(t, h2Mode)
}
// testTransportUserAgent checks which User-Agent values reach the server for
// various client-side settings. The handler echoes the received values
// formatted with %q.
func testTransportUserAgent(t *testing.T, h2 bool) {
	defer afterTest(t)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
	}))
	defer cst.close()

	// Without an explicit User-Agent, a protocol-specific default is sent.
	defaultUA := `["Go-http-client/1.1"]`
	if h2 {
		defaultUA = `["Go-http-client/2.0"]`
	}

	cases := []struct {
		setup func(*Request)
		want  string
	}{
		{func(*Request) {}, defaultUA},
		{func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") }, `["foo/1.2.3"]`},
		// Only the first of several values is sent.
		{func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} }, `["single"]`},
		// An empty or nil value suppresses the header entirely.
		{func(r *Request) { r.Header.Set("User-Agent", "") }, `[]`},
		{func(r *Request) { r.Header["User-Agent"] = nil }, `[]`},
	}

	for i, c := range cases {
		req, _ := NewRequest("GET", cst.ts.URL, nil)
		c.setup(req)
		res, err := cst.c.Do(req)
		if err != nil {
			t.Errorf("%d. RoundTrip = %v", i, err)
			continue
		}
		got, err := io.ReadAll(res.Body)
		res.Body.Close()
		switch {
		case err != nil:
			t.Errorf("%d. read body = %v", i, err)
		case string(got) != c.want:
			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, got, c.want)
		}
	}
}
900
// TestStarRequestFoo_h1 sends "FOO *" over HTTP/1.
func TestStarRequestFoo_h1(t *testing.T) {
	testStarRequest(t, "FOO", h1Mode)
}

// TestStarRequestFoo_h2 sends "FOO *" over HTTP/2.
func TestStarRequestFoo_h2(t *testing.T) {
	testStarRequest(t, "FOO", h2Mode)
}

// TestStarRequestOptions_h1 sends "OPTIONS *" over HTTP/1.
func TestStarRequestOptions_h1(t *testing.T) {
	testStarRequest(t, "OPTIONS", h1Mode)
}

// TestStarRequestOptions_h2 sends "OPTIONS *" over HTTP/2.
func TestStarRequestOptions_h2(t *testing.T) {
	testStarRequest(t, "OPTIONS", h2Mode)
}
905 func testStarRequest(t *testing.T, method string, h2 bool) {
906 defer afterTest(t)
907 gotc := make(chan *Request, 1)
908 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
909 w.Header().Set("foo", "bar")
910 gotc <- r
911 w.(Flusher).Flush()
912 }))
913 defer cst.close()
914
915 u, err := url.Parse(cst.ts.URL)
916 if err != nil {
917 t.Fatal(err)
918 }
919 u.Path = "*"
920
921 req := &Request{
922 Method: method,
923 Header: Header{},
924 URL: u,
925 }
926
927 res, err := cst.c.Do(req)
928 if err != nil {
929 t.Fatalf("RoundTrip = %v", err)
930 }
931 res.Body.Close()
932
933 wantFoo := "bar"
934 wantLen := int64(-1)
935 if method == "OPTIONS" {
936 wantFoo = ""
937 wantLen = 0
938 }
939 if res.StatusCode != 200 {
940 t.Errorf("status code = %v; want %d", res.Status, 200)
941 }
942 if res.ContentLength != wantLen {
943 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
944 }
945 if got := res.Header.Get("foo"); got != wantFoo {
946 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
947 }
948 select {
949 case req = <-gotc:
950 default:
951 req = nil
952 }
953 if req == nil {
954 if method != "OPTIONS" {
955 t.Fatalf("handler never got request")
956 }
957 return
958 }
959 if req.Method != method {
960 t.Errorf("method = %q; want %q", req.Method, method)
961 }
962 if req.URL.Path != "*" {
963 t.Errorf("URL.Path = %q; want *", req.URL.Path)
964 }
965 if req.RequestURI != "*" {
966 t.Errorf("RequestURI = %q; want *", req.RequestURI)
967 }
968 }
969
970
// TestTransportDiscardsUnneededConns fires N concurrent GETs through an
// HTTP/2 transport whose DialTLS is slowed down and counts opens and closes.
// All responses must be identical (the server echoes the client address, so
// identical bodies mean one shared connection). All but one of the dialed
// connections must eventually be closed.
func TestTransportDiscardsUnneededConns(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
	}))
	defer cst.close()

	var numOpen, numClose int32

	tlsConfig := &tls.Config{InsecureSkipVerify: true}
	tr := &Transport{
		TLSClientConfig: tlsConfig,
		DialTLS: func(_, addr string) (net.Conn, error) {
			// The delay lets the concurrent requests race to dial.
			time.Sleep(10 * time.Millisecond)
			rc, err := net.Dial("tcp", addr)
			if err != nil {
				return nil, err
			}
			atomic.AddInt32(&numOpen, 1)
			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
			return tls.Client(c, tlsConfig), nil
		},
	}
	if err := ExportHttp2ConfigureTransport(tr); err != nil {
		t.Fatal(err)
	}
	defer tr.CloseIdleConnections()

	c := &Client{Transport: tr}

	const N = 10
	gotBody := make(chan string, N)
	var wg sync.WaitGroup
	for i := 0; i < N; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			resp, err := c.Get(cst.ts.URL)
			if err != nil {
				// Retry once after a short pause; a request may presumably
				// land on a connection that is being discarded.
				time.Sleep(10 * time.Millisecond)
				resp, err = c.Get(cst.ts.URL)
				if err != nil {
					t.Errorf("Get: %v", err)
					return
				}
			}
			defer resp.Body.Close()
			slurp, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Error(err)
			}
			gotBody <- string(slurp)
		}()
	}
	wg.Wait()
	close(gotBody)

	var last string
	for got := range gotBody {
		if last == "" {
			last = got
			continue
		}
		if got != last {
			t.Errorf("Response body changed: %q -> %q", last, got)
		}
	}

	// Named to avoid shadowing the builtin close.
	var opened, closed int32
	for i := 0; i < 150; i++ {
		opened, closed = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
		if opened < 1 {
			t.Fatalf("open = %d; want at least 1", opened)
		}
		if closed == opened-1 {
			// Exactly one connection remains in use.
			return
		}
		time.Sleep(10 * time.Millisecond)
	}
	t.Errorf("%d connections opened, %d closed; want %d to close", opened, closed, opened-1)
}
1056
1057
// TestTransportGCRequest_Body_h1: HTTP/1, response with a body.
func TestTransportGCRequest_Body_h1(t *testing.T) {
	testTransportGCRequest(t, h1Mode, true)
}

// TestTransportGCRequest_Body_h2: HTTP/2, response with a body.
func TestTransportGCRequest_Body_h2(t *testing.T) {
	testTransportGCRequest(t, h2Mode, true)
}

// TestTransportGCRequest_NoBody_h1: HTTP/1, empty response.
func TestTransportGCRequest_NoBody_h1(t *testing.T) {
	testTransportGCRequest(t, h1Mode, false)
}

// TestTransportGCRequest_NoBody_h2: HTTP/2, empty response.
func TestTransportGCRequest_NoBody_h2(t *testing.T) {
	testTransportGCRequest(t, h2Mode, false)
}
// testTransportGCRequest checks that a completed POST request can be garbage
// collected. A finalizer on the request signals collection, and GC is forced
// every 100ms until it fires or 5 seconds pass. When body is true the
// handler also writes a response body.
func testTransportGCRequest(t *testing.T, h2, body bool) {
	setParallel(t)
	defer afterTest(t)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.ReadAll(r.Body)
		if body {
			io.WriteString(w, "Hello.")
		}
	}))
	defer cst.close()

	didGC := make(chan struct{})
	// Issue the request inside a function literal so that req is unreachable
	// once it returns.
	func() {
		reqBody := strings.NewReader("some body")
		req, _ := NewRequest("POST", cst.ts.URL, reqBody)
		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.ReadAll(res.Body); err != nil {
			t.Fatal(err)
		}
		if err := res.Body.Close(); err != nil {
			t.Fatal(err)
		}
	}()

	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()
	for {
		select {
		case <-didGC:
			return
		case <-deadline.C:
			t.Fatal("never saw GC of request")
		case <-time.After(100 * time.Millisecond):
			runtime.GC()
		}
	}
}
1102
// TestTransportRejectsInvalidHeaders_h1 runs testTransportRejectsInvalidHeaders over HTTP/1.
func TestTransportRejectsInvalidHeaders_h1(t *testing.T) { testTransportRejectsInvalidHeaders(t, h1Mode) }

// TestTransportRejectsInvalidHeaders_h2 runs testTransportRejectsInvalidHeaders over HTTP/2.
func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { testTransportRejectsInvalidHeaders(t, h2Mode) }
// testTransportRejectsInvalidHeaders sends requests carrying one header each.
// For invalid keys or values the transport must fail locally, without
// dialing. Valid ones must succeed. Keep-alives are disabled so every
// request that reaches the network dials afresh and the dial is observable.
func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) {
	setParallel(t)
	defer afterTest(t)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
	}), optQuietLog)
	defer cst.close()
	cst.tr.DisableKeepAlives = true

	cases := []struct {
		key, val string
		ok       bool
	}{
		{"Foo", "capital-key", true},
		{"Foo", "foo\x00bar", false},
		{"Foo", "two\nlines", false},
		{"bogus\nkey", "v", false},
		{"A space", "v", false},
		{"имя", "v", false},
		{"name", "валю", true},
		{"", "v", false},
		{"k", "", true},
	}
	for _, c := range cases {
		dialed := make(chan bool, 1)
		cst.tr.Dial = func(network, addr string) (net.Conn, error) {
			dialed <- true
			return net.Dial(network, addr)
		}
		req, _ := NewRequest("GET", cst.ts.URL, nil)
		req.Header[c.key] = []string{c.val}
		res, err := cst.c.Do(req)
		var body []byte
		if err == nil {
			body, _ = io.ReadAll(res.Body)
			res.Body.Close()
		}

		didDial := false
		select {
		case <-dialed:
			didDial = true
		default:
		}

		switch {
		case !c.ok && didDial:
			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", c.key, c.val, res, err, body)
		case (err == nil) != c.ok:
			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", c.key, c.val, err, c.ok)
		}
	}
}
1160
// TestInterruptWithPanic_h1: HTTP/1 handler panics with a string.
func TestInterruptWithPanic_h1(t *testing.T) {
	testInterruptWithPanic(t, h1Mode, "boom")
}

// TestInterruptWithPanic_h2: HTTP/2 handler panics with a string.
func TestInterruptWithPanic_h2(t *testing.T) {
	testInterruptWithPanic(t, h2Mode, "boom")
}

// TestInterruptWithPanic_nil_h1: HTTP/1 handler panics with nil.
func TestInterruptWithPanic_nil_h1(t *testing.T) {
	testInterruptWithPanic(t, h1Mode, nil)
}

// TestInterruptWithPanic_nil_h2: HTTP/2 handler panics with nil.
func TestInterruptWithPanic_nil_h2(t *testing.T) {
	testInterruptWithPanic(t, h2Mode, nil)
}

// TestInterruptWithPanic_ErrAbortHandler_h1: HTTP/1 handler panics with ErrAbortHandler.
func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) {
	testInterruptWithPanic(t, h1Mode, ErrAbortHandler)
}

// TestInterruptWithPanic_ErrAbortHandler_h2: HTTP/2 handler panics with ErrAbortHandler.
func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) {
	testInterruptWithPanic(t, h2Mode, ErrAbortHandler)
}
// testInterruptWithPanic has the handler flush a partial body and then panic
// with panicValue. The client must receive the flushed bytes followed by a
// read error. The server's error log must contain a stack trace for ordinary
// panic values and stay empty for nil and ErrAbortHandler.
func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) {
	setParallel(t)
	const msg = "hello"
	defer afterTest(t)

	testDone := make(chan struct{})
	defer close(testDone)

	var errorLog lockedBytesBuffer
	gotHeaders := make(chan bool, 1)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()
		// Panic only once the client has the response headers, or once the
		// test is over.
		select {
		case <-gotHeaders:
		case <-testDone:
		}
		panic(panicValue)
	})
	captureLog := func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
	}
	cst := newClientServerTest(t, h2, handler, captureLog)
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	gotHeaders <- true
	defer res.Body.Close()

	slurp, err := io.ReadAll(res.Body)
	if string(slurp) != msg {
		t.Errorf("client read %q; want %q", slurp, msg)
	}
	if err == nil {
		t.Errorf("client read all successfully; want some error")
	}

	logOutput := func() string {
		errorLog.Lock()
		defer errorLog.Unlock()
		return errorLog.String()
	}
	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler

	err = waitErrCondition(5*time.Second, 10*time.Millisecond, func() error {
		gotLog := logOutput()
		switch {
		case !wantStackLogged && gotLog == "":
			return nil
		case !wantStackLogged:
			return fmt.Errorf("want no log output; got: %s", gotLog)
		case gotLog == "":
			return fmt.Errorf("wanted a stack trace logged; got nothing")
		case !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6:
			return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog)
		}
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
}
1233
// lockedBytesBuffer is a bytes.Buffer whose Write method holds the embedded
// mutex, so it can safely back a log.Logger that server goroutines write to.
// Code that reads the buffer must take the lock itself; only Write is wrapped.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the underlying buffer while holding the lock and
// returns the buffer's result unchanged.
func (lb *lockedBytesBuffer) Write(p []byte) (n int, err error) {
	lb.Mutex.Lock()
	defer lb.Mutex.Unlock()
	n, err = lb.Buffer.Write(p)
	return n, err
}
1244
1245
1246 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1247 h12Compare{
1248 Handler: func(w ResponseWriter, r *Request) {
1249 h := w.Header()
1250 h.Set("Content-Encoding", "gzip")
1251 h.Set("Content-Length", "23")
1252 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1253 },
1254 EarlyCheckResponse: func(proto string, res *Response) {
1255 if !res.Uncompressed {
1256 t.Errorf("%s: expected Uncompressed to be set", proto)
1257 }
1258 dump, err := httputil.DumpResponse(res, true)
1259 if err != nil {
1260 t.Errorf("%s: DumpResponse: %v", proto, err)
1261 return
1262 }
1263 if strings.Contains(string(dump), "Connection: close") {
1264 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1265 }
1266 if !strings.Contains(string(dump), "FOO") {
1267 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1268 }
1269 },
1270 }.run(t)
1271 }
1272
1273
1274 func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) }
1275 func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) }
1276 func testCloseIdleConnections(t *testing.T, h2 bool) {
1277 setParallel(t)
1278 defer afterTest(t)
1279 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1280 w.Header().Set("X-Addr", r.RemoteAddr)
1281 }))
1282 defer cst.close()
1283 get := func() string {
1284 res, err := cst.c.Get(cst.ts.URL)
1285 if err != nil {
1286 t.Fatal(err)
1287 }
1288 res.Body.Close()
1289 v := res.Header.Get("X-Addr")
1290 if v == "" {
1291 t.Fatal("didn't get X-Addr")
1292 }
1293 return v
1294 }
1295 a1 := get()
1296 cst.tr.CloseIdleConnections()
1297 a2 := get()
1298 if a1 == a2 {
1299 t.Errorf("didn't close connection")
1300 }
1301 }
1302
// noteCloseConn wraps a net.Conn so that closeFunc is invoked on every Close,
// before the wrapped connection itself is closed.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close runs the notification hook and then closes the embedded Conn,
// returning that Conn's Close error.
func (nc noteCloseConn) Close() error {
	notify := nc.closeFunc
	notify()
	return nc.Conn.Close()
}
1312
// testErrorReader is an io.Reader for request bodies that must never be
// consumed: every Read reports a test error on t.
type testErrorReader struct{ t *testing.T }

// Read records the unexpected call as a test error and reports end of input.
func (r testErrorReader) Read(p []byte) (n int, err error) {
	r.t.Error("unexpected Read call")
	n, err = 0, io.EOF
	return n, err
}
1319
1320 func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) }
1321 func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) }
1322
1323 func testNoSniffExpectRequestBody(t *testing.T, h2 bool) {
1324 defer afterTest(t)
1325 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1326 w.WriteHeader(StatusUnauthorized)
1327 }))
1328 defer cst.close()
1329
1330
1331 cst.tr.ExpectContinueTimeout = 10 * time.Second
1332
1333 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1334 if err != nil {
1335 t.Fatal(err)
1336 }
1337 req.ContentLength = 0
1338 req.Header.Set("Expect", "100-continue")
1339 res, err := cst.tr.RoundTrip(req)
1340 if err != nil {
1341 t.Fatal(err)
1342 }
1343 defer res.Body.Close()
1344 if res.StatusCode != StatusUnauthorized {
1345 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1346 }
1347 }
1348
1349 func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) }
1350 func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) }
1351 func testServerUndeclaredTrailers(t *testing.T, h2 bool) {
1352 defer afterTest(t)
1353 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1354 w.Header().Set("Foo", "Bar")
1355 w.Header().Set("Trailer:Foo", "Baz")
1356 w.(Flusher).Flush()
1357 w.Header().Add("Trailer:Foo", "Baz2")
1358 w.Header().Set("Trailer:Bar", "Quux")
1359 }))
1360 defer cst.close()
1361 res, err := cst.c.Get(cst.ts.URL)
1362 if err != nil {
1363 t.Fatal(err)
1364 }
1365 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1366 t.Fatal(err)
1367 }
1368 res.Body.Close()
1369 delete(res.Header, "Date")
1370 delete(res.Header, "Content-Type")
1371
1372 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1373 t.Errorf("Header = %#v; want %#v", res.Header, want)
1374 }
1375 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1376 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1377 }
1378 }
1379
1380 func TestBadResponseAfterReadingBody(t *testing.T) {
1381 defer afterTest(t)
1382 cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) {
1383 _, err := io.Copy(io.Discard, r.Body)
1384 if err != nil {
1385 t.Fatal(err)
1386 }
1387 c, _, err := w.(Hijacker).Hijack()
1388 if err != nil {
1389 t.Fatal(err)
1390 }
1391 defer c.Close()
1392 fmt.Fprintln(c, "some bogus crap")
1393 }))
1394 defer cst.close()
1395
1396 closes := 0
1397 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1398 if err == nil {
1399 res.Body.Close()
1400 t.Fatal("expected an error to be returned from Post")
1401 }
1402 if closes != 1 {
1403 t.Errorf("closes = %d; want 1", closes)
1404 }
1405 }
1406
1407 func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) }
1408 func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) }
1409 func testWriteHeader0(t *testing.T, h2 bool) {
1410 defer afterTest(t)
1411 gotpanic := make(chan bool, 1)
1412 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1413 defer close(gotpanic)
1414 defer func() {
1415 if e := recover(); e != nil {
1416 got := fmt.Sprintf("%T, %v", e, e)
1417 want := "string, invalid WriteHeader code 0"
1418 if got != want {
1419 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1420 }
1421 gotpanic <- true
1422
1423
1424
1425
1426 w.WriteHeader(503)
1427 }
1428 }()
1429 w.WriteHeader(0)
1430 }))
1431 defer cst.close()
1432 res, err := cst.c.Get(cst.ts.URL)
1433 if err != nil {
1434 t.Fatal(err)
1435 }
1436 if res.StatusCode != 503 {
1437 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1438 }
1439 if !<-gotpanic {
1440 t.Error("expected panic in handler")
1441 }
1442 }
1443
1444
1445
1446 func TestWriteHeaderNoCodeCheck_h1(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, false) }
1447 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) }
1448 func TestWriteHeaderNoCodeCheck_h2(t *testing.T) { testWriteHeaderAfterWrite(t, h2Mode, false) }
1449 func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) {
1450 setParallel(t)
1451 defer afterTest(t)
1452
1453 var errorLog lockedBytesBuffer
1454 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1455 if hijack {
1456 conn, _, _ := w.(Hijacker).Hijack()
1457 defer conn.Close()
1458 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1459 w.WriteHeader(0)
1460 conn.Write([]byte("bar"))
1461 return
1462 }
1463 io.WriteString(w, "foo")
1464 w.(Flusher).Flush()
1465 w.WriteHeader(0)
1466 io.WriteString(w, "bar")
1467 }), func(ts *httptest.Server) {
1468 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1469 })
1470 defer cst.close()
1471 res, err := cst.c.Get(cst.ts.URL)
1472 if err != nil {
1473 t.Fatal(err)
1474 }
1475 defer res.Body.Close()
1476 body, err := io.ReadAll(res.Body)
1477 if err != nil {
1478 t.Fatal(err)
1479 }
1480 if got, want := string(body), "foobar"; got != want {
1481 t.Errorf("got = %q; want %q", got, want)
1482 }
1483
1484
1485 if h2 {
1486
1487
1488 return
1489 }
1490 gotLog := strings.TrimSpace(errorLog.String())
1491 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1492 if hijack {
1493 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1494 }
1495 if !strings.HasPrefix(gotLog, wantLog) {
1496 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1497 }
1498 }
1499
1500 func TestBidiStreamReverseProxy(t *testing.T) {
1501 setParallel(t)
1502 defer afterTest(t)
1503 backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1504 if _, err := io.Copy(w, r.Body); err != nil {
1505 log.Printf("bidi backend copy: %v", err)
1506 }
1507 }))
1508 defer backend.close()
1509
1510 backURL, err := url.Parse(backend.ts.URL)
1511 if err != nil {
1512 t.Fatal(err)
1513 }
1514 rp := httputil.NewSingleHostReverseProxy(backURL)
1515 rp.Transport = backend.tr
1516 proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1517 rp.ServeHTTP(w, r)
1518 }))
1519 defer proxy.close()
1520
1521 bodyRes := make(chan any, 1)
1522 pr, pw := io.Pipe()
1523 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1524 const size = 4 << 20
1525 go func() {
1526 h := sha1.New()
1527 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1528 go pw.Close()
1529 if err != nil {
1530 bodyRes <- err
1531 } else {
1532 bodyRes <- h
1533 }
1534 }()
1535 res, err := backend.c.Do(req)
1536 if err != nil {
1537 t.Fatal(err)
1538 }
1539 defer res.Body.Close()
1540 hgot := sha1.New()
1541 n, err := io.Copy(hgot, res.Body)
1542 if err != nil {
1543 t.Fatal(err)
1544 }
1545 if n != size {
1546 t.Fatalf("got %d bytes; want %d", n, size)
1547 }
1548 select {
1549 case v := <-bodyRes:
1550 switch v := v.(type) {
1551 default:
1552 t.Fatalf("body copy: %v", err)
1553 case hash.Hash:
1554 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1555 t.Errorf("written bytes didn't match received bytes")
1556 }
1557 }
1558 case <-time.After(10 * time.Second):
1559 t.Fatal("timeout")
1560 }
1561
1562 }
1563
1564
1565 func TestH12_WebSocketUpgrade(t *testing.T) {
1566 h12Compare{
1567 Handler: func(w ResponseWriter, r *Request) {
1568 h := w.Header()
1569 h.Set("Foo", "bar")
1570 },
1571 ReqFunc: func(c *Client, url string) (*Response, error) {
1572 req, _ := NewRequest("GET", url, nil)
1573 req.Header.Set("Connection", "Upgrade")
1574 req.Header.Set("Upgrade", "WebSocket")
1575 return c.Do(req)
1576 },
1577 EarlyCheckResponse: func(proto string, res *Response) {
1578 if res.Proto != "HTTP/1.1" {
1579 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1580 }
1581 res.Proto = "HTTP/IGNORE"
1582 },
1583 }.run(t)
1584 }
1585
1586 func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) }
1587 func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) }
1588
1589 func testIdentityTransferEncoding(t *testing.T, h2 bool) {
1590 setParallel(t)
1591 defer afterTest(t)
1592
1593 const body = "body"
1594 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1595 gotBody, _ := io.ReadAll(r.Body)
1596 if got, want := string(gotBody), body; got != want {
1597 t.Errorf("got request body = %q; want %q", got, want)
1598 }
1599 w.Header().Set("Transfer-Encoding", "identity")
1600 w.WriteHeader(StatusOK)
1601 w.(Flusher).Flush()
1602 io.WriteString(w, body)
1603 }))
1604 defer cst.close()
1605 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1606 res, err := cst.c.Do(req)
1607 if err != nil {
1608 t.Fatal(err)
1609 }
1610 defer res.Body.Close()
1611 gotBody, err := io.ReadAll(res.Body)
1612 if err != nil {
1613 t.Fatal(err)
1614 }
1615 if got, want := string(gotBody), body; got != want {
1616 t.Errorf("got response body = %q; want %q", got, want)
1617 }
1618 }
1619
View as plain text