1
2
3
4
5
6
7 package httputil
8
9 import (
10 "bufio"
11 "bytes"
12 "context"
13 "errors"
14 "fmt"
15 "io"
16 "log"
17 "net/http"
18 "net/http/httptest"
19 "net/http/internal/ascii"
20 "net/url"
21 "os"
22 "reflect"
23 "sort"
24 "strconv"
25 "strings"
26 "sync"
27 "testing"
28 "time"
29 )
30
// fakeHopHeader is an extra header name that this package's init appends to
// hopHeaders, so tests can check that the proxy strips it from responses.
const (
	fakeHopHeader = "X-Fake-Hop-Header-For-Test"
)
32
33 func init() {
34 inOurTests = true
35 hopHeaders = append(hopHeaders, fakeHopHeader)
36 }
37
38 func TestReverseProxy(t *testing.T) {
39 const backendResponse = "I am the backend"
40 const backendStatus = 404
41 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
42 if r.Method == "GET" && r.FormValue("mode") == "hangup" {
43 c, _, _ := w.(http.Hijacker).Hijack()
44 c.Close()
45 return
46 }
47 if len(r.TransferEncoding) > 0 {
48 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
49 }
50 if r.Header.Get("X-Forwarded-For") == "" {
51 t.Errorf("didn't get X-Forwarded-For header")
52 }
53 if c := r.Header.Get("Connection"); c != "" {
54 t.Errorf("handler got Connection header value %q", c)
55 }
56 if c := r.Header.Get("Te"); c != "trailers" {
57 t.Errorf("handler got Te header value %q; want 'trailers'", c)
58 }
59 if c := r.Header.Get("Upgrade"); c != "" {
60 t.Errorf("handler got Upgrade header value %q", c)
61 }
62 if c := r.Header.Get("Proxy-Connection"); c != "" {
63 t.Errorf("handler got Proxy-Connection header value %q", c)
64 }
65 if g, e := r.Host, "some-name"; g != e {
66 t.Errorf("backend got Host header %q, want %q", g, e)
67 }
68 w.Header().Set("Trailers", "not a special header field name")
69 w.Header().Set("Trailer", "X-Trailer")
70 w.Header().Set("X-Foo", "bar")
71 w.Header().Set("Upgrade", "foo")
72 w.Header().Set(fakeHopHeader, "foo")
73 w.Header().Add("X-Multi-Value", "foo")
74 w.Header().Add("X-Multi-Value", "bar")
75 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
76 w.WriteHeader(backendStatus)
77 w.Write([]byte(backendResponse))
78 w.Header().Set("X-Trailer", "trailer_value")
79 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
80 }))
81 defer backend.Close()
82 backendURL, err := url.Parse(backend.URL)
83 if err != nil {
84 t.Fatal(err)
85 }
86 proxyHandler := NewSingleHostReverseProxy(backendURL)
87 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
88 frontend := httptest.NewServer(proxyHandler)
89 defer frontend.Close()
90 frontendClient := frontend.Client()
91
92 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
93 getReq.Host = "some-name"
94 getReq.Header.Set("Connection", "close, TE")
95 getReq.Header.Add("Te", "foo")
96 getReq.Header.Add("Te", "bar, trailers")
97 getReq.Header.Set("Proxy-Connection", "should be deleted")
98 getReq.Header.Set("Upgrade", "foo")
99 getReq.Close = true
100 res, err := frontendClient.Do(getReq)
101 if err != nil {
102 t.Fatalf("Get: %v", err)
103 }
104 if g, e := res.StatusCode, backendStatus; g != e {
105 t.Errorf("got res.StatusCode %d; expected %d", g, e)
106 }
107 if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
108 t.Errorf("got X-Foo %q; expected %q", g, e)
109 }
110 if c := res.Header.Get(fakeHopHeader); c != "" {
111 t.Errorf("got %s header value %q", fakeHopHeader, c)
112 }
113 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
114 t.Errorf("header Trailers = %q; want %q", g, e)
115 }
116 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
117 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
118 }
119 if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
120 t.Fatalf("got %d SetCookies, want %d", g, e)
121 }
122 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
123 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
124 }
125 if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
126 t.Errorf("unexpected cookie %q", cookie.Name)
127 }
128 bodyBytes, _ := io.ReadAll(res.Body)
129 if g, e := string(bodyBytes), backendResponse; g != e {
130 t.Errorf("got body %q; expected %q", g, e)
131 }
132 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
133 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
134 }
135 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
136 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
137 }
138
139
140
141 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
142 getReq.Close = true
143 res, err = frontendClient.Do(getReq)
144 if err != nil {
145 t.Fatal(err)
146 }
147 res.Body.Close()
148 if res.StatusCode != http.StatusBadGateway {
149 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
150 }
151
152 }
153
154
155
156 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
157 const fakeConnectionToken = "X-Fake-Connection-Token"
158 const backendResponse = "I am the backend"
159
160
161
162 const someConnHeader = "X-Some-Conn-Header"
163
164 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
165 if c := r.Header.Get("Connection"); c != "" {
166 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
167 }
168 if c := r.Header.Get(fakeConnectionToken); c != "" {
169 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
170 }
171 if c := r.Header.Get(someConnHeader); c != "" {
172 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
173 }
174 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
175 w.Header().Add("Connection", someConnHeader)
176 w.Header().Set(someConnHeader, "should be deleted")
177 w.Header().Set(fakeConnectionToken, "should be deleted")
178 io.WriteString(w, backendResponse)
179 }))
180 defer backend.Close()
181 backendURL, err := url.Parse(backend.URL)
182 if err != nil {
183 t.Fatal(err)
184 }
185 proxyHandler := NewSingleHostReverseProxy(backendURL)
186 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
187 proxyHandler.ServeHTTP(w, r)
188 if c := r.Header.Get(someConnHeader); c != "should be deleted" {
189 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
190 }
191 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
192 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
193 }
194 c := r.Header["Connection"]
195 var cf []string
196 for _, f := range c {
197 for _, sf := range strings.Split(f, ",") {
198 if sf = strings.TrimSpace(sf); sf != "" {
199 cf = append(cf, sf)
200 }
201 }
202 }
203 sort.Strings(cf)
204 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
205 sort.Strings(expectedValues)
206 if !reflect.DeepEqual(cf, expectedValues) {
207 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
208 }
209 }))
210 defer frontend.Close()
211
212 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
213 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
214 getReq.Header.Add("Connection", someConnHeader)
215 getReq.Header.Set(someConnHeader, "should be deleted")
216 getReq.Header.Set(fakeConnectionToken, "should be deleted")
217 res, err := frontend.Client().Do(getReq)
218 if err != nil {
219 t.Fatalf("Get: %v", err)
220 }
221 defer res.Body.Close()
222 bodyBytes, err := io.ReadAll(res.Body)
223 if err != nil {
224 t.Fatalf("reading body: %v", err)
225 }
226 if got, want := string(bodyBytes), backendResponse; got != want {
227 t.Errorf("got body %q; want %q", got, want)
228 }
229 if c := res.Header.Get("Connection"); c != "" {
230 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
231 }
232 if c := res.Header.Get(someConnHeader); c != "" {
233 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
234 }
235 if c := res.Header.Get(fakeConnectionToken); c != "" {
236 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
237 }
238 }
239
240 func TestReverseProxyStripEmptyConnection(t *testing.T) {
241
242 const backendResponse = "I am the backend"
243
244
245
246 const someConnHeader = "X-Some-Conn-Header"
247
248 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
249 if c := r.Header.Values("Connection"); len(c) != 0 {
250 t.Errorf("handler got header %q = %v; want empty", "Connection", c)
251 }
252 if c := r.Header.Get(someConnHeader); c != "" {
253 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
254 }
255 w.Header().Add("Connection", "")
256 w.Header().Add("Connection", someConnHeader)
257 w.Header().Set(someConnHeader, "should be deleted")
258 io.WriteString(w, backendResponse)
259 }))
260 defer backend.Close()
261 backendURL, err := url.Parse(backend.URL)
262 if err != nil {
263 t.Fatal(err)
264 }
265 proxyHandler := NewSingleHostReverseProxy(backendURL)
266 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
267 proxyHandler.ServeHTTP(w, r)
268 if c := r.Header.Get(someConnHeader); c != "should be deleted" {
269 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
270 }
271 }))
272 defer frontend.Close()
273
274 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
275 getReq.Header.Add("Connection", "")
276 getReq.Header.Add("Connection", someConnHeader)
277 getReq.Header.Set(someConnHeader, "should be deleted")
278 res, err := frontend.Client().Do(getReq)
279 if err != nil {
280 t.Fatalf("Get: %v", err)
281 }
282 defer res.Body.Close()
283 bodyBytes, err := io.ReadAll(res.Body)
284 if err != nil {
285 t.Fatalf("reading body: %v", err)
286 }
287 if got, want := string(bodyBytes), backendResponse; got != want {
288 t.Errorf("got body %q; want %q", got, want)
289 }
290 if c := res.Header.Get("Connection"); c != "" {
291 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
292 }
293 if c := res.Header.Get(someConnHeader); c != "" {
294 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
295 }
296 }
297
298 func TestXForwardedFor(t *testing.T) {
299 const prevForwardedFor = "client ip"
300 const backendResponse = "I am the backend"
301 const backendStatus = 404
302 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
303 if r.Header.Get("X-Forwarded-For") == "" {
304 t.Errorf("didn't get X-Forwarded-For header")
305 }
306 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
307 t.Errorf("X-Forwarded-For didn't contain prior data")
308 }
309 w.WriteHeader(backendStatus)
310 w.Write([]byte(backendResponse))
311 }))
312 defer backend.Close()
313 backendURL, err := url.Parse(backend.URL)
314 if err != nil {
315 t.Fatal(err)
316 }
317 proxyHandler := NewSingleHostReverseProxy(backendURL)
318 frontend := httptest.NewServer(proxyHandler)
319 defer frontend.Close()
320
321 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
322 getReq.Host = "some-name"
323 getReq.Header.Set("Connection", "close")
324 getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
325 getReq.Close = true
326 res, err := frontend.Client().Do(getReq)
327 if err != nil {
328 t.Fatalf("Get: %v", err)
329 }
330 if g, e := res.StatusCode, backendStatus; g != e {
331 t.Errorf("got res.StatusCode %d; expected %d", g, e)
332 }
333 bodyBytes, _ := io.ReadAll(res.Body)
334 if g, e := string(bodyBytes), backendResponse; g != e {
335 t.Errorf("got body %q; expected %q", g, e)
336 }
337 }
338
339
340 func TestXForwardedFor_Omit(t *testing.T) {
341 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
342 if v := r.Header.Get("X-Forwarded-For"); v != "" {
343 t.Errorf("got X-Forwarded-For header: %q", v)
344 }
345 w.Write([]byte("hi"))
346 }))
347 defer backend.Close()
348 backendURL, err := url.Parse(backend.URL)
349 if err != nil {
350 t.Fatal(err)
351 }
352 proxyHandler := NewSingleHostReverseProxy(backendURL)
353 frontend := httptest.NewServer(proxyHandler)
354 defer frontend.Close()
355
356 oldDirector := proxyHandler.Director
357 proxyHandler.Director = func(r *http.Request) {
358 r.Header["X-Forwarded-For"] = nil
359 oldDirector(r)
360 }
361
362 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
363 getReq.Host = "some-name"
364 getReq.Close = true
365 res, err := frontend.Client().Do(getReq)
366 if err != nil {
367 t.Fatalf("Get: %v", err)
368 }
369 res.Body.Close()
370 }
371
// proxyQueryTests lists query-string merging cases for TestReverseProxyQuery.
var proxyQueryTests = []struct {
	baseSuffix string // appended to the backend (target) URL
	reqSuffix  string // appended to the URL requested from the frontend
	want       string // RawQuery the backend is expected to observe
}{
	{baseSuffix: "", reqSuffix: "", want: ""},
	{baseSuffix: "?sta=tic", reqSuffix: "?us=er", want: "sta=tic&us=er"},
	{baseSuffix: "", reqSuffix: "?us=er", want: "us=er"},
	{baseSuffix: "?sta=tic", reqSuffix: "", want: "sta=tic"},
}
382
383 func TestReverseProxyQuery(t *testing.T) {
384 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
385 w.Header().Set("X-Got-Query", r.URL.RawQuery)
386 w.Write([]byte("hi"))
387 }))
388 defer backend.Close()
389
390 for i, tt := range proxyQueryTests {
391 backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
392 if err != nil {
393 t.Fatal(err)
394 }
395 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
396 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
397 req.Close = true
398 res, err := frontend.Client().Do(req)
399 if err != nil {
400 t.Fatalf("%d. Get: %v", i, err)
401 }
402 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
403 t.Errorf("%d. got query %q; expected %q", i, g, e)
404 }
405 res.Body.Close()
406 frontend.Close()
407 }
408 }
409
410 func TestReverseProxyFlushInterval(t *testing.T) {
411 const expected = "hi"
412 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
413 w.Write([]byte(expected))
414 }))
415 defer backend.Close()
416
417 backendURL, err := url.Parse(backend.URL)
418 if err != nil {
419 t.Fatal(err)
420 }
421
422 proxyHandler := NewSingleHostReverseProxy(backendURL)
423 proxyHandler.FlushInterval = time.Microsecond
424
425 frontend := httptest.NewServer(proxyHandler)
426 defer frontend.Close()
427
428 req, _ := http.NewRequest("GET", frontend.URL, nil)
429 req.Close = true
430 res, err := frontend.Client().Do(req)
431 if err != nil {
432 t.Fatalf("Get: %v", err)
433 }
434 defer res.Body.Close()
435 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
436 t.Errorf("got body %q; expected %q", bodyBytes, expected)
437 }
438 }
439
440 func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
441 const expected = "hi"
442 stopCh := make(chan struct{})
443 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
444 w.Header().Add("MyHeader", expected)
445 w.WriteHeader(200)
446 w.(http.Flusher).Flush()
447 <-stopCh
448 }))
449 defer backend.Close()
450 defer close(stopCh)
451
452 backendURL, err := url.Parse(backend.URL)
453 if err != nil {
454 t.Fatal(err)
455 }
456
457 proxyHandler := NewSingleHostReverseProxy(backendURL)
458 proxyHandler.FlushInterval = time.Microsecond
459
460 frontend := httptest.NewServer(proxyHandler)
461 defer frontend.Close()
462
463 req, _ := http.NewRequest("GET", frontend.URL, nil)
464 req.Close = true
465
466 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
467 defer cancel()
468 req = req.WithContext(ctx)
469
470 res, err := frontend.Client().Do(req)
471 if err != nil {
472 t.Fatalf("Get: %v", err)
473 }
474 defer res.Body.Close()
475
476 if res.Header.Get("MyHeader") != expected {
477 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
478 }
479 }
480
481 func TestReverseProxyCancellation(t *testing.T) {
482 const backendResponse = "I am the backend"
483
484 reqInFlight := make(chan struct{})
485 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
486 close(reqInFlight)
487
488 select {
489 case <-time.After(10 * time.Second):
490
491
492 t.Error("Handler never saw CloseNotify")
493 return
494 case <-w.(http.CloseNotifier).CloseNotify():
495 }
496
497 w.WriteHeader(http.StatusOK)
498 w.Write([]byte(backendResponse))
499 }))
500
501 defer backend.Close()
502
503 backend.Config.ErrorLog = log.New(io.Discard, "", 0)
504
505 backendURL, err := url.Parse(backend.URL)
506 if err != nil {
507 t.Fatal(err)
508 }
509
510 proxyHandler := NewSingleHostReverseProxy(backendURL)
511
512
513
514 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
515
516 frontend := httptest.NewServer(proxyHandler)
517 defer frontend.Close()
518 frontendClient := frontend.Client()
519
520 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
521 go func() {
522 <-reqInFlight
523 frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
524 }()
525 res, err := frontendClient.Do(getReq)
526 if res != nil {
527 t.Errorf("got response %v; want nil", res.Status)
528 }
529 if err == nil {
530
531
532
533 t.Error("Server.Client().Do() returned nil error; want non-nil error")
534 }
535 }
536
// req parses v as a raw HTTP/1.x request and returns it. The test fails
// immediately (t.Fatal) if v cannot be parsed.
func req(t *testing.T, v string) *http.Request {
	r, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
	if err == nil {
		return r
	}
	t.Fatal(err)
	return nil
}
544
545
546 func TestNilBody(t *testing.T) {
547 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
548 w.Write([]byte("hi"))
549 }))
550 defer backend.Close()
551
552 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
553 backURL, _ := url.Parse(backend.URL)
554 rp := NewSingleHostReverseProxy(backURL)
555 r := req(t, "GET / HTTP/1.0\r\n\r\n")
556 r.Body = nil
557 rp.ServeHTTP(w, r)
558 }))
559 defer frontend.Close()
560
561 res, err := http.Get(frontend.URL)
562 if err != nil {
563 t.Fatal(err)
564 }
565 defer res.Body.Close()
566 slurp, err := io.ReadAll(res.Body)
567 if err != nil {
568 t.Fatal(err)
569 }
570 if string(slurp) != "hi" {
571 t.Errorf("Got %q; want %q", slurp, "hi")
572 }
573 }
574
575
576 func TestUserAgentHeader(t *testing.T) {
577 const explicitUA = "explicit UA"
578 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
579 if r.URL.Path == "/noua" {
580 if c := r.Header.Get("User-Agent"); c != "" {
581 t.Errorf("handler got non-empty User-Agent header %q", c)
582 }
583 return
584 }
585 if c := r.Header.Get("User-Agent"); c != explicitUA {
586 t.Errorf("handler got unexpected User-Agent header %q", c)
587 }
588 }))
589 defer backend.Close()
590 backendURL, err := url.Parse(backend.URL)
591 if err != nil {
592 t.Fatal(err)
593 }
594 proxyHandler := NewSingleHostReverseProxy(backendURL)
595 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
596 frontend := httptest.NewServer(proxyHandler)
597 defer frontend.Close()
598 frontendClient := frontend.Client()
599
600 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
601 getReq.Header.Set("User-Agent", explicitUA)
602 getReq.Close = true
603 res, err := frontendClient.Do(getReq)
604 if err != nil {
605 t.Fatalf("Get: %v", err)
606 }
607 res.Body.Close()
608
609 getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
610 getReq.Header.Set("User-Agent", "")
611 getReq.Close = true
612 res, err = frontendClient.Do(getReq)
613 if err != nil {
614 t.Fatalf("Get: %v", err)
615 }
616 res.Body.Close()
617 }
618
// bufferPool adapts a pair of functions into a value that can be assigned to
// ReverseProxy.BufferPool, letting a test observe every buffer Get and Put.
type bufferPool struct {
	get func() []byte
	put func([]byte)
}

// Get returns whatever bp.get produces.
func (bp bufferPool) Get() []byte {
	return bp.get()
}

// Put passes v to bp.put.
func (bp bufferPool) Put(v []byte) {
	bp.put(v)
}
626
627 func TestReverseProxyGetPutBuffer(t *testing.T) {
628 const msg = "hi"
629 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
630 io.WriteString(w, msg)
631 }))
632 defer backend.Close()
633
634 backendURL, err := url.Parse(backend.URL)
635 if err != nil {
636 t.Fatal(err)
637 }
638
639 var (
640 mu sync.Mutex
641 log []string
642 )
643 addLog := func(event string) {
644 mu.Lock()
645 defer mu.Unlock()
646 log = append(log, event)
647 }
648 rp := NewSingleHostReverseProxy(backendURL)
649 const size = 1234
650 rp.BufferPool = bufferPool{
651 get: func() []byte {
652 addLog("getBuf")
653 return make([]byte, size)
654 },
655 put: func(p []byte) {
656 addLog("putBuf-" + strconv.Itoa(len(p)))
657 },
658 }
659 frontend := httptest.NewServer(rp)
660 defer frontend.Close()
661
662 req, _ := http.NewRequest("GET", frontend.URL, nil)
663 req.Close = true
664 res, err := frontend.Client().Do(req)
665 if err != nil {
666 t.Fatalf("Get: %v", err)
667 }
668 slurp, err := io.ReadAll(res.Body)
669 res.Body.Close()
670 if err != nil {
671 t.Fatalf("reading body: %v", err)
672 }
673 if string(slurp) != msg {
674 t.Errorf("msg = %q; want %q", slurp, msg)
675 }
676 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
677 mu.Lock()
678 defer mu.Unlock()
679 if !reflect.DeepEqual(log, wantLog) {
680 t.Errorf("Log events = %q; want %q", log, wantLog)
681 }
682 }
683
684 func TestReverseProxy_Post(t *testing.T) {
685 const backendResponse = "I am the backend"
686 const backendStatus = 200
687 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
688 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
689 slurp, err := io.ReadAll(r.Body)
690 if err != nil {
691 t.Errorf("Backend body read = %v", err)
692 }
693 if len(slurp) != len(requestBody) {
694 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
695 }
696 if !bytes.Equal(slurp, requestBody) {
697 t.Error("Backend read wrong request body.")
698 }
699 w.Write([]byte(backendResponse))
700 }))
701 defer backend.Close()
702 backendURL, err := url.Parse(backend.URL)
703 if err != nil {
704 t.Fatal(err)
705 }
706 proxyHandler := NewSingleHostReverseProxy(backendURL)
707 frontend := httptest.NewServer(proxyHandler)
708 defer frontend.Close()
709
710 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
711 res, err := frontend.Client().Do(postReq)
712 if err != nil {
713 t.Fatalf("Do: %v", err)
714 }
715 if g, e := res.StatusCode, backendStatus; g != e {
716 t.Errorf("got res.StatusCode %d; expected %d", g, e)
717 }
718 bodyBytes, _ := io.ReadAll(res.Body)
719 if g, e := string(bodyBytes), backendResponse; g != e {
720 t.Errorf("got body %q; expected %q", g, e)
721 }
722 }
723
// RoundTripperFunc lets an ordinary function serve as an http.RoundTripper.
type RoundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip calls fn with req and returns its results unchanged.
func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	res, err := fn(req)
	return res, err
}
729
730
731 func TestReverseProxy_NilBody(t *testing.T) {
732 backendURL, _ := url.Parse("http://fake.tld/")
733 proxyHandler := NewSingleHostReverseProxy(backendURL)
734 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
735 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
736 if req.Body != nil {
737 t.Error("Body != nil; want a nil Body")
738 }
739 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
740 })
741 frontend := httptest.NewServer(proxyHandler)
742 defer frontend.Close()
743
744 res, err := frontend.Client().Get(frontend.URL)
745 if err != nil {
746 t.Fatal(err)
747 }
748 defer res.Body.Close()
749 if res.StatusCode != 502 {
750 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
751 }
752 }
753
754
755 func TestReverseProxy_AllocatedHeader(t *testing.T) {
756 proxyHandler := new(ReverseProxy)
757 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
758 proxyHandler.Director = func(*http.Request) {}
759 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
760 if req.Header == nil {
761 t.Error("Header == nil; want a non-nil Header")
762 }
763 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
764 })
765
766 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
767 Method: "GET",
768 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
769 Proto: "HTTP/1.0",
770 ProtoMajor: 1,
771 })
772 }
773
774
775
776 func TestReverseProxyModifyResponse(t *testing.T) {
777 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
778 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
779 }))
780 defer backendServer.Close()
781
782 rpURL, _ := url.Parse(backendServer.URL)
783 rproxy := NewSingleHostReverseProxy(rpURL)
784 rproxy.ErrorLog = log.New(io.Discard, "", 0)
785 rproxy.ModifyResponse = func(resp *http.Response) error {
786 if resp.Header.Get("X-Hit-Mod") != "true" {
787 return fmt.Errorf("tried to by-pass proxy")
788 }
789 return nil
790 }
791
792 frontendProxy := httptest.NewServer(rproxy)
793 defer frontendProxy.Close()
794
795 tests := []struct {
796 url string
797 wantCode int
798 }{
799 {frontendProxy.URL + "/mod", http.StatusOK},
800 {frontendProxy.URL + "/schedule", http.StatusBadGateway},
801 }
802
803 for i, tt := range tests {
804 resp, err := http.Get(tt.url)
805 if err != nil {
806 t.Fatalf("failed to reach proxy: %v", err)
807 }
808 if g, e := resp.StatusCode, tt.wantCode; g != e {
809 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
810 }
811 resp.Body.Close()
812 }
813 }
814
// failingRoundTripper is an http.RoundTripper that always fails: it returns
// no response and the error "some error".
type failingRoundTripper struct{}

// RoundTrip ignores its request and reports the fixed error.
func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
	var noResponse *http.Response
	return noResponse, errors.New("some error")
}
820
// staticResponseRoundTripper is an http.RoundTripper that returns the same
// *http.Response for every request, with a nil error.
type staticResponseRoundTripper struct {
	res *http.Response
}

// RoundTrip ignores its request and hands back rt.res.
func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
	canned := rt.res
	return canned, nil
}
826
827 func TestReverseProxyErrorHandler(t *testing.T) {
828 tests := []struct {
829 name string
830 wantCode int
831 errorHandler func(http.ResponseWriter, *http.Request, error)
832 transport http.RoundTripper
833 modifyResponse func(*http.Response) error
834 }{
835 {
836 name: "default",
837 wantCode: http.StatusBadGateway,
838 },
839 {
840 name: "errorhandler",
841 wantCode: http.StatusTeapot,
842 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
843 },
844 {
845 name: "modifyresponse_noerr",
846 transport: staticResponseRoundTripper{
847 &http.Response{StatusCode: 345, Body: http.NoBody},
848 },
849 modifyResponse: func(res *http.Response) error {
850 res.StatusCode++
851 return nil
852 },
853 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
854 wantCode: 346,
855 },
856 {
857 name: "modifyresponse_err",
858 transport: staticResponseRoundTripper{
859 &http.Response{StatusCode: 345, Body: http.NoBody},
860 },
861 modifyResponse: func(res *http.Response) error {
862 res.StatusCode++
863 return errors.New("some error to trigger errorHandler")
864 },
865 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
866 wantCode: http.StatusTeapot,
867 },
868 }
869
870 for _, tt := range tests {
871 t.Run(tt.name, func(t *testing.T) {
872 target := &url.URL{
873 Scheme: "http",
874 Host: "dummy.tld",
875 Path: "/",
876 }
877 rproxy := NewSingleHostReverseProxy(target)
878 rproxy.Transport = tt.transport
879 rproxy.ModifyResponse = tt.modifyResponse
880 if rproxy.Transport == nil {
881 rproxy.Transport = failingRoundTripper{}
882 }
883 rproxy.ErrorLog = log.New(io.Discard, "", 0)
884 if tt.errorHandler != nil {
885 rproxy.ErrorHandler = tt.errorHandler
886 }
887 frontendProxy := httptest.NewServer(rproxy)
888 defer frontendProxy.Close()
889
890 resp, err := http.Get(frontendProxy.URL + "/test")
891 if err != nil {
892 t.Fatalf("failed to reach proxy: %v", err)
893 }
894 if g, e := resp.StatusCode, tt.wantCode; g != e {
895 t.Errorf("got res.StatusCode %d; expected %d", g, e)
896 }
897 resp.Body.Close()
898 })
899 }
900 }
901
902
903 func TestReverseProxy_CopyBuffer(t *testing.T) {
904 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
905 out := "this call was relayed by the reverse proxy"
906
907 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
908 fmt.Fprintln(w, out)
909 }))
910 defer backendServer.Close()
911
912 rpURL, err := url.Parse(backendServer.URL)
913 if err != nil {
914 t.Fatal(err)
915 }
916
917 var proxyLog bytes.Buffer
918 rproxy := NewSingleHostReverseProxy(rpURL)
919 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
920 donec := make(chan bool, 1)
921 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
922 defer func() { donec <- true }()
923 rproxy.ServeHTTP(w, r)
924 }))
925 defer frontendProxy.Close()
926
927 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
928 t.Fatalf("want non-nil error")
929 }
930
931
932
933
934 <-donec
935
936 expected := []string{
937 "EOF",
938 "read",
939 }
940 for _, phrase := range expected {
941 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
942 t.Errorf("expected log to contain phrase %q", phrase)
943 }
944 }
945 }
946
// staticTransport is an http.RoundTripper that answers every request with the
// same preconfigured response and a nil error.
type staticTransport struct {
	res *http.Response
}

// RoundTrip ignores the request and returns t.res.
func (t *staticTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	canned := t.res
	return canned, nil
}
954
955 func BenchmarkServeHTTP(b *testing.B) {
956 res := &http.Response{
957 StatusCode: 200,
958 Body: io.NopCloser(strings.NewReader("")),
959 }
960 proxy := &ReverseProxy{
961 Director: func(*http.Request) {},
962 Transport: &staticTransport{res},
963 }
964
965 w := httptest.NewRecorder()
966 r := httptest.NewRequest("GET", "/", nil)
967
968 b.ReportAllocs()
969 for i := 0; i < b.N; i++ {
970 proxy.ServeHTTP(w, r)
971 }
972 }
973
974 func TestServeHTTPDeepCopy(t *testing.T) {
975 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
976 w.Write([]byte("Hello Gopher!"))
977 }))
978 defer backend.Close()
979 backendURL, err := url.Parse(backend.URL)
980 if err != nil {
981 t.Fatal(err)
982 }
983
984 type result struct {
985 before, after string
986 }
987
988 resultChan := make(chan result, 1)
989 proxyHandler := NewSingleHostReverseProxy(backendURL)
990 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
991 before := r.URL.String()
992 proxyHandler.ServeHTTP(w, r)
993 after := r.URL.String()
994 resultChan <- result{before: before, after: after}
995 }))
996 defer frontend.Close()
997
998 want := result{before: "/", after: "/"}
999
1000 res, err := frontend.Client().Get(frontend.URL)
1001 if err != nil {
1002 t.Fatalf("Do: %v", err)
1003 }
1004 res.Body.Close()
1005
1006 got := <-resultChan
1007 if got != want {
1008 t.Errorf("got = %+v; want = %+v", got, want)
1009 }
1010 }
1011
1012
1013
1014 func TestClonesRequestHeaders(t *testing.T) {
1015 log.SetOutput(io.Discard)
1016 defer log.SetOutput(os.Stderr)
1017 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1018 req.RemoteAddr = "1.2.3.4:56789"
1019 rp := &ReverseProxy{
1020 Director: func(req *http.Request) {
1021 req.Header.Set("From-Director", "1")
1022 },
1023 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
1024 if v := req.Header.Get("From-Director"); v != "1" {
1025 t.Errorf("From-Directory value = %q; want 1", v)
1026 }
1027 return nil, io.EOF
1028 }),
1029 }
1030 rp.ServeHTTP(httptest.NewRecorder(), req)
1031
1032 if req.Header.Get("From-Director") == "1" {
1033 t.Error("Director header mutation modified caller's request")
1034 }
1035 if req.Header.Get("X-Forwarded-For") != "" {
1036 t.Error("X-Forward-For header mutation modified caller's request")
1037 }
1038
1039 }
1040
// roundTripperFunc lets an ordinary function serve as an http.RoundTripper.
type roundTripperFunc func(req *http.Request) (*http.Response, error)

// RoundTrip calls fn with req and returns its results unchanged.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	res, err := fn(req)
	return res, err
}
1046
1047 func TestModifyResponseClosesBody(t *testing.T) {
1048 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1049 req.RemoteAddr = "1.2.3.4:56789"
1050 closeCheck := new(checkCloser)
1051 logBuf := new(bytes.Buffer)
1052 outErr := errors.New("ModifyResponse error")
1053 rp := &ReverseProxy{
1054 Director: func(req *http.Request) {},
1055 Transport: &staticTransport{&http.Response{
1056 StatusCode: 200,
1057 Body: closeCheck,
1058 }},
1059 ErrorLog: log.New(logBuf, "", 0),
1060 ModifyResponse: func(*http.Response) error {
1061 return outErr
1062 },
1063 }
1064 rec := httptest.NewRecorder()
1065 rp.ServeHTTP(rec, req)
1066 res := rec.Result()
1067 if g, e := res.StatusCode, http.StatusBadGateway; g != e {
1068 t.Errorf("got res.StatusCode %d; expected %d", g, e)
1069 }
1070 if !closeCheck.closed {
1071 t.Errorf("body should have been closed")
1072 }
1073 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
1074 t.Errorf("ErrorLog %q does not contain %q", g, e)
1075 }
1076 }
1077
// checkCloser is an io.ReadCloser test double that records whether Close
// has been called. Read never fails and never reports EOF.
type checkCloser struct {
	closed bool // set by Close
}

// Close marks c as closed and always returns nil.
func (c *checkCloser) Close() error {
	c.closed = true
	return nil
}

// Read claims to have filled all of p (without writing to it) and never
// returns an error.
func (c *checkCloser) Read(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}
1090
1091
1092 func TestReverseProxy_PanicBodyError(t *testing.T) {
1093 log.SetOutput(io.Discard)
1094 defer log.SetOutput(os.Stderr)
1095 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1096 out := "this call was relayed by the reverse proxy"
1097
1098 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1099 fmt.Fprintln(w, out)
1100 }))
1101 defer backendServer.Close()
1102
1103 rpURL, err := url.Parse(backendServer.URL)
1104 if err != nil {
1105 t.Fatal(err)
1106 }
1107
1108 rproxy := NewSingleHostReverseProxy(rpURL)
1109
1110
1111
1112 defer func() {
1113 err := recover()
1114 if err == nil {
1115 t.Fatal("handler should have panicked")
1116 }
1117 if err != http.ErrAbortHandler {
1118 t.Fatal("expected ErrAbortHandler, got", err)
1119 }
1120 }()
1121 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1122 rproxy.ServeHTTP(httptest.NewRecorder(), req)
1123 }
1124
1125
1126 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
1127 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1128 out := "this call was relayed by the reverse proxy"
1129
1130 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1131 fmt.Fprintln(w, out)
1132 }))
1133 defer backend.Close()
1134 backendURL, err := url.Parse(backend.URL)
1135 if err != nil {
1136 t.Fatal(err)
1137 }
1138 proxyHandler := NewSingleHostReverseProxy(backendURL)
1139 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1140 frontend := httptest.NewServer(proxyHandler)
1141 defer frontend.Close()
1142 frontendClient := frontend.Client()
1143
1144 var wg sync.WaitGroup
1145 for i := 0; i < 2; i++ {
1146 wg.Add(1)
1147 go func() {
1148 defer wg.Done()
1149 for j := 0; j < 10; j++ {
1150 const reqLen = 6 * 1024 * 1024
1151 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
1152 req.ContentLength = reqLen
1153 resp, _ := frontendClient.Transport.RoundTrip(req)
1154 if resp != nil {
1155 io.Copy(io.Discard, resp.Body)
1156 resp.Body.Close()
1157 }
1158 }
1159 }()
1160 }
1161 wg.Wait()
1162 }
1163
1164 func TestSelectFlushInterval(t *testing.T) {
1165 tests := []struct {
1166 name string
1167 p *ReverseProxy
1168 res *http.Response
1169 want time.Duration
1170 }{
1171 {
1172 name: "default",
1173 res: &http.Response{},
1174 p: &ReverseProxy{FlushInterval: 123},
1175 want: 123,
1176 },
1177 {
1178 name: "server-sent events overrides non-zero",
1179 res: &http.Response{
1180 Header: http.Header{
1181 "Content-Type": {"text/event-stream"},
1182 },
1183 },
1184 p: &ReverseProxy{FlushInterval: 123},
1185 want: -1,
1186 },
1187 {
1188 name: "server-sent events overrides zero",
1189 res: &http.Response{
1190 Header: http.Header{
1191 "Content-Type": {"text/event-stream"},
1192 },
1193 },
1194 p: &ReverseProxy{FlushInterval: 0},
1195 want: -1,
1196 },
1197 {
1198 name: "server-sent events with media-type parameters overrides non-zero",
1199 res: &http.Response{
1200 Header: http.Header{
1201 "Content-Type": {"text/event-stream;charset=utf-8"},
1202 },
1203 },
1204 p: &ReverseProxy{FlushInterval: 123},
1205 want: -1,
1206 },
1207 {
1208 name: "server-sent events with media-type parameters overrides zero",
1209 res: &http.Response{
1210 Header: http.Header{
1211 "Content-Type": {"text/event-stream;charset=utf-8"},
1212 },
1213 },
1214 p: &ReverseProxy{FlushInterval: 0},
1215 want: -1,
1216 },
1217 {
1218 name: "Content-Length: -1, overrides non-zero",
1219 res: &http.Response{
1220 ContentLength: -1,
1221 },
1222 p: &ReverseProxy{FlushInterval: 123},
1223 want: -1,
1224 },
1225 {
1226 name: "Content-Length: -1, overrides zero",
1227 res: &http.Response{
1228 ContentLength: -1,
1229 },
1230 p: &ReverseProxy{FlushInterval: 0},
1231 want: -1,
1232 },
1233 }
1234 for _, tt := range tests {
1235 t.Run(tt.name, func(t *testing.T) {
1236 got := tt.p.flushInterval(tt.res)
1237 if got != tt.want {
1238 t.Errorf("flushLatency = %v; want %v", got, tt.want)
1239 }
1240 })
1241 }
1242 }
1243
1244 func TestReverseProxyWebSocket(t *testing.T) {
1245 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1246 if upgradeType(r.Header) != "websocket" {
1247 t.Error("unexpected backend request")
1248 http.Error(w, "unexpected request", 400)
1249 return
1250 }
1251 c, _, err := w.(http.Hijacker).Hijack()
1252 if err != nil {
1253 t.Error(err)
1254 return
1255 }
1256 defer c.Close()
1257 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
1258 bs := bufio.NewScanner(c)
1259 if !bs.Scan() {
1260 t.Errorf("backend failed to read line from client: %v", bs.Err())
1261 return
1262 }
1263 fmt.Fprintf(c, "backend got %q\n", bs.Text())
1264 }))
1265 defer backendServer.Close()
1266
1267 backURL, _ := url.Parse(backendServer.URL)
1268 rproxy := NewSingleHostReverseProxy(backURL)
1269 rproxy.ErrorLog = log.New(io.Discard, "", 0)
1270 rproxy.ModifyResponse = func(res *http.Response) error {
1271 res.Header.Add("X-Modified", "true")
1272 return nil
1273 }
1274
1275 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1276 rw.Header().Set("X-Header", "X-Value")
1277 rproxy.ServeHTTP(rw, req)
1278 if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
1279 t.Errorf("response writer X-Modified header = %q; want %q", got, want)
1280 }
1281 })
1282
1283 frontendProxy := httptest.NewServer(handler)
1284 defer frontendProxy.Close()
1285
1286 req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1287 req.Header.Set("Connection", "Upgrade")
1288 req.Header.Set("Upgrade", "websocket")
1289
1290 c := frontendProxy.Client()
1291 res, err := c.Do(req)
1292 if err != nil {
1293 t.Fatal(err)
1294 }
1295 if res.StatusCode != 101 {
1296 t.Fatalf("status = %v; want 101", res.Status)
1297 }
1298
1299 got := res.Header.Get("X-Header")
1300 want := "X-Value"
1301 if got != want {
1302 t.Errorf("Header(XHeader) = %q; want %q", got, want)
1303 }
1304
1305 if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
1306 t.Fatalf("not websocket upgrade; got %#v", res.Header)
1307 }
1308 rwc, ok := res.Body.(io.ReadWriteCloser)
1309 if !ok {
1310 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
1311 }
1312 defer rwc.Close()
1313
1314 if got, want := res.Header.Get("X-Modified"), "true"; got != want {
1315 t.Errorf("response X-Modified header = %q; want %q", got, want)
1316 }
1317
1318 io.WriteString(rwc, "Hello\n")
1319 bs := bufio.NewScanner(rwc)
1320 if !bs.Scan() {
1321 t.Fatalf("Scan: %v", bs.Err())
1322 }
1323 got = bs.Text()
1324 want = `backend got "Hello"`
1325 if got != want {
1326 t.Errorf("got %#q, want %#q", got, want)
1327 }
1328 }
1329
1330 func TestReverseProxyWebSocketCancellation(t *testing.T) {
1331 n := 5
1332 triggerCancelCh := make(chan bool, n)
1333 nthResponse := func(i int) string {
1334 return fmt.Sprintf("backend response #%d\n", i)
1335 }
1336 terminalMsg := "final message"
1337
1338 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1339 if g, ws := upgradeType(r.Header), "websocket"; g != ws {
1340 t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
1341 http.Error(w, "Unexpected request", 400)
1342 return
1343 }
1344 conn, bufrw, err := w.(http.Hijacker).Hijack()
1345 if err != nil {
1346 t.Error(err)
1347 return
1348 }
1349 defer conn.Close()
1350
1351 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
1352 if _, err := io.WriteString(conn, upgradeMsg); err != nil {
1353 t.Error(err)
1354 return
1355 }
1356 if _, _, err := bufrw.ReadLine(); err != nil {
1357 t.Errorf("Failed to read line from client: %v", err)
1358 return
1359 }
1360
1361 for i := 0; i < n; i++ {
1362 if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
1363 select {
1364 case <-triggerCancelCh:
1365 default:
1366 t.Errorf("Writing response #%d failed: %v", i, err)
1367 }
1368 return
1369 }
1370 bufrw.Flush()
1371 time.Sleep(time.Second)
1372 }
1373 if _, err := bufrw.WriteString(terminalMsg); err != nil {
1374 select {
1375 case <-triggerCancelCh:
1376 default:
1377 t.Errorf("Failed to write terminal message: %v", err)
1378 }
1379 }
1380 bufrw.Flush()
1381 }))
1382 defer cst.Close()
1383
1384 backendURL, _ := url.Parse(cst.URL)
1385 rproxy := NewSingleHostReverseProxy(backendURL)
1386 rproxy.ErrorLog = log.New(io.Discard, "", 0)
1387 rproxy.ModifyResponse = func(res *http.Response) error {
1388 res.Header.Add("X-Modified", "true")
1389 return nil
1390 }
1391
1392 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1393 rw.Header().Set("X-Header", "X-Value")
1394 ctx, cancel := context.WithCancel(req.Context())
1395 go func() {
1396 <-triggerCancelCh
1397 cancel()
1398 }()
1399 rproxy.ServeHTTP(rw, req.WithContext(ctx))
1400 })
1401
1402 frontendProxy := httptest.NewServer(handler)
1403 defer frontendProxy.Close()
1404
1405 req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1406 req.Header.Set("Connection", "Upgrade")
1407 req.Header.Set("Upgrade", "websocket")
1408
1409 res, err := frontendProxy.Client().Do(req)
1410 if err != nil {
1411 t.Fatalf("Dialing to frontend proxy: %v", err)
1412 }
1413 defer res.Body.Close()
1414 if g, w := res.StatusCode, 101; g != w {
1415 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
1416 }
1417
1418 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
1419 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
1420 }
1421
1422 if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
1423 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
1424 }
1425
1426 rwc, ok := res.Body.(io.ReadWriteCloser)
1427 if !ok {
1428 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
1429 }
1430
1431 if got, want := res.Header.Get("X-Modified"), "true"; got != want {
1432 t.Errorf("response X-Modified header = %q; want %q", got, want)
1433 }
1434
1435 if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
1436 t.Fatalf("Failed to write first message: %v", err)
1437 }
1438
1439
1440
1441 br := bufio.NewReader(rwc)
1442 for {
1443 line, err := br.ReadString('\n')
1444 switch {
1445 case line == terminalMsg:
1446 t.Fatalf("The websocket request was not canceled, unfortunately!")
1447
1448 case err == io.EOF:
1449 return
1450
1451 case err != nil:
1452 t.Fatalf("Unexpected error: %v", err)
1453
1454 case line == nthResponse(0):
1455
1456 close(triggerCancelCh)
1457 }
1458 }
1459 }
1460
1461 func TestUnannouncedTrailer(t *testing.T) {
1462 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1463 w.WriteHeader(http.StatusOK)
1464 w.(http.Flusher).Flush()
1465 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
1466 }))
1467 defer backend.Close()
1468 backendURL, err := url.Parse(backend.URL)
1469 if err != nil {
1470 t.Fatal(err)
1471 }
1472 proxyHandler := NewSingleHostReverseProxy(backendURL)
1473 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1474 frontend := httptest.NewServer(proxyHandler)
1475 defer frontend.Close()
1476 frontendClient := frontend.Client()
1477
1478 res, err := frontendClient.Get(frontend.URL)
1479 if err != nil {
1480 t.Fatalf("Get: %v", err)
1481 }
1482
1483 io.ReadAll(res.Body)
1484
1485 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
1486 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
1487 }
1488
1489 }
1490
1491 func TestSingleJoinSlash(t *testing.T) {
1492 tests := []struct {
1493 slasha string
1494 slashb string
1495 expected string
1496 }{
1497 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
1498 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
1499 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
1500 {"https://www.google.com", "", "https://www.google.com/"},
1501 {"", "favicon.ico", "/favicon.ico"},
1502 }
1503 for _, tt := range tests {
1504 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
1505 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
1506 tt.slasha,
1507 tt.slashb,
1508 tt.expected,
1509 got)
1510 }
1511 }
1512 }
1513
1514 func TestJoinURLPath(t *testing.T) {
1515 tests := []struct {
1516 a *url.URL
1517 b *url.URL
1518 wantPath string
1519 wantRaw string
1520 }{
1521 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
1522 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
1523 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
1524 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
1525 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
1526 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
1527 }
1528
1529 for _, tt := range tests {
1530 p, rp := joinURLPath(tt.a, tt.b)
1531 if p != tt.wantPath || rp != tt.wantRaw {
1532 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
1533 tt.a.Path, tt.a.RawPath,
1534 tt.b.Path, tt.b.RawPath,
1535 tt.wantPath, tt.wantRaw,
1536 p, rp)
1537 }
1538 }
1539 }
1540
View as plain text