1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "fmt"
12 "io"
13 "log"
14 "mime"
15 "net"
16 "net/http"
17 "net/http/internal/ascii"
18 "net/textproto"
19 "net/url"
20 "strings"
21 "sync"
22 "time"
23
24 "golang.org/x/net/http/httpguts"
25 )
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43 type ReverseProxy struct {
44
45
46
47
48
49
50 Director func(*http.Request)
51
52
53
54 Transport http.RoundTripper
55
56
57
58
59
60
61
62
63
64
65
66 FlushInterval time.Duration
67
68
69
70
71 ErrorLog *log.Logger
72
73
74
75
76 BufferPool BufferPool
77
78
79
80
81
82
83
84
85
86
87 ModifyResponse func(*http.Response) error
88
89
90
91
92
93
94 ErrorHandler func(http.ResponseWriter, *http.Request, error)
95 }
96
97
98
// BufferPool hands out and takes back the byte slices that the proxy uses as
// scratch space while copying response bodies.
type BufferPool interface {
	// Get returns a buffer to copy through.
	Get() []byte
	// Put gives a buffer previously obtained from Get back to the pool.
	Put(buf []byte)
}
103
// singleJoiningSlash concatenates a and b. When a ends with "/" and b starts
// with "/", one of the two slashes is dropped; when neither side supplies a
// slash, one is inserted between them.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")

	if trailing && leading {
		// Both sides bring a slash: keep a's, drop b's.
		return a + b[1:]
	}
	if trailing || leading {
		// Exactly one side brings the slash already.
		return a + b
	}
	return a + "/" + b
}
115
116 func joinURLPath(a, b *url.URL) (path, rawpath string) {
117 if a.RawPath == "" && b.RawPath == "" {
118 return singleJoiningSlash(a.Path, b.Path), ""
119 }
120
121
122 apath := a.EscapedPath()
123 bpath := b.EscapedPath()
124
125 aslash := strings.HasSuffix(apath, "/")
126 bslash := strings.HasPrefix(bpath, "/")
127
128 switch {
129 case aslash && bslash:
130 return a.Path + b.Path[1:], apath + bpath[1:]
131 case !aslash && !bslash:
132 return a.Path + "/" + b.Path, apath + "/" + bpath
133 }
134 return a.Path + b.Path, apath + bpath
135 }
136
137
138
139
140
141
142
143
144 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
145 targetQuery := target.RawQuery
146 director := func(req *http.Request) {
147 req.URL.Scheme = target.Scheme
148 req.URL.Host = target.Host
149 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
150 if targetQuery == "" || req.URL.RawQuery == "" {
151 req.URL.RawQuery = targetQuery + req.URL.RawQuery
152 } else {
153 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
154 }
155 if _, ok := req.Header["User-Agent"]; !ok {
156
157 req.Header.Set("User-Agent", "")
158 }
159 }
160 return &ReverseProxy{Director: director}
161 }
162
// copyHeader appends every value of every key in src to dst. Values already
// present in dst are kept; src is not modified.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
170
171
172
173
174
175
// hopHeaders lists the header names that ServeHTTP deletes from both the
// outbound request and the backend response before forwarding them.
// The names are written in canonical header-key form.
176 var hopHeaders = []string{
177 "Connection",
178 "Proxy-Connection",
179 "Keep-Alive",
180 "Proxy-Authenticate",
181 "Proxy-Authorization",
182 "Te", // canonical form of "TE"
183 "Trailer",
184 "Transfer-Encoding",
185 "Upgrade",
186 }
187
188 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
189 p.logf("http: proxy error: %v", err)
190 rw.WriteHeader(http.StatusBadGateway)
191 }
192
193 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
194 if p.ErrorHandler != nil {
195 return p.ErrorHandler
196 }
197 return p.defaultErrorHandler
198 }
199
200
201
202 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
203 if p.ModifyResponse == nil {
204 return true
205 }
206 if err := p.ModifyResponse(res); err != nil {
207 res.Body.Close()
208 p.getErrorHandler()(rw, req, err)
209 return false
210 }
211 return true
212 }
213
// ServeHTTP proxies req to the backend and relays the answer to rw.
//
// It clones the request, lets Director rewrite the clone, strips hop-by-hop
// headers, appends the client address to X-Forwarded-For, and performs the
// round trip through Transport (http.DefaultTransport when nil). A 101
// response is handed to handleUpgradeResponse; any other response has its
// hop-by-hop headers removed, is passed through modifyResponse, and is then
// copied to rw together with its trailers.
//
// Round-trip errors are reported via the error handler. An error while
// copying the body either panics with http.ErrAbortHandler or is only
// logged, as decided by shouldPanicOnCopyError.
214 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
215 transport := p.Transport
216 if transport == nil {
217 transport = http.DefaultTransport
218 }
219 // If rw can signal client disconnects, cancel the outbound context when that happens (or when this handler returns).
220 ctx := req.Context()
221 if cn, ok := rw.(http.CloseNotifier); ok {
222 var cancel context.CancelFunc
223 ctx, cancel = context.WithCancel(ctx)
224 defer cancel()
225 notifyChan := cn.CloseNotify()
226 go func() {
227 select {
228 case <-notifyChan:
229 cancel()
230 case <-ctx.Done():
231 }
232 }()
233 }
234
235 outreq := req.Clone(ctx)
236 if req.ContentLength == 0 {
237 outreq.Body = nil // a request that declares no content is forwarded without a body
238 }
239 if outreq.Body != nil {
240 // Close the outbound body when this handler returns.
241 // NOTE(review): presumably this guards against the transport still
242 // reading the body after the handler has finished — confirm.
243
244
245
246 defer outreq.Body.Close()
247 }
248 if outreq.Header == nil {
249 outreq.Header = make(http.Header) // Director always sees a non-nil header map
250 }
251
252 p.Director(outreq)
253 outreq.Close = false
254 // Record the requested upgrade protocol before the headers are stripped; refuse values ascii.IsPrint rejects.
255 reqUpType := upgradeType(outreq.Header)
256 if !ascii.IsPrint(reqUpType) {
257 p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
258 return
259 }
260 removeConnectionHeaders(outreq.Header)
261
262 // Drop the fixed set of hop-by-hop headers from the outbound request,
263 // after the headers named by Connection were removed above.
264
265 for _, h := range hopHeaders {
266 outreq.Header.Del(h)
267 }
268
269 // Pass "Te: trailers" on to the backend when the client sent it.
270 // The original req.Header is inspected because Te has just been
271 // deleted from outreq.Header.
272
273
274 if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
275 outreq.Header.Set("Te", "trailers")
276 }
277
278 // Put back the Connection/Upgrade pair for upgrade requests; both were
279 // removed with the other hop-by-hop headers.
280 if reqUpType != "" {
281 outreq.Header.Set("Connection", "Upgrade")
282 outreq.Header.Set("Upgrade", reqUpType)
283 }
284
285 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
286 // Append the client IP to any existing X-Forwarded-For values, folding
287 // them into one ", "-separated header. If the key is present with a
288 // nil value (e.g. set so by Director), the header is left out entirely.
289 prior, ok := outreq.Header["X-Forwarded-For"]
290 omit := ok && prior == nil
291 if len(prior) > 0 {
292 clientIP = strings.Join(prior, ", ") + ", " + clientIP
293 }
294 if !omit {
295 outreq.Header.Set("X-Forwarded-For", clientIP)
296 }
297 }
298
299 res, err := transport.RoundTrip(outreq)
300 if err != nil {
301 p.getErrorHandler()(rw, outreq, err)
302 return
303 }
304
305 // 101 Switching Protocols: run the response hook, then hand the connection over.
306 if res.StatusCode == http.StatusSwitchingProtocols {
307 if !p.modifyResponse(rw, res, outreq) {
308 return
309 }
310 p.handleUpgradeResponse(rw, outreq, res)
311 return
312 }
313
314 removeConnectionHeaders(res.Header)
315
316 for _, h := range hopHeaders {
317 res.Header.Del(h)
318 }
319
320 if !p.modifyResponse(rw, res, outreq) {
321 return
322 }
323
324 copyHeader(rw.Header(), res.Header)
325
326 // Announce the trailer keys known at this point in a Trailer header;
327 // the count is remembered to detect trailers that show up later.
328 announcedTrailers := len(res.Trailer)
329 if announcedTrailers > 0 {
330 trailerKeys := make([]string, 0, len(res.Trailer))
331 for k := range res.Trailer {
332 trailerKeys = append(trailerKeys, k)
333 }
334 rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
335 }
336
337 rw.WriteHeader(res.StatusCode)
338
339 err = p.copyResponse(rw, res.Body, p.flushInterval(res))
340 if err != nil {
341 defer res.Body.Close()
342 // The status and headers are already written, so the failure cannot
343 // be reported to the client: either abort the handler by panicking
344 // or, when shouldPanicOnCopyError says no, just log it.
345 if !shouldPanicOnCopyError(req) {
346 p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
347 return
348 }
349 panic(http.ErrAbortHandler)
350 }
351 res.Body.Close() // closed before the trailer handling below rather than deferred
352
353 if len(res.Trailer) > 0 {
354 // Flush before the trailers are added to the header map.
355 // NOTE(review): presumably this forces a chunked response so the
356 // trailers can actually be sent — confirm against net/http.
357 if fl, ok := rw.(http.Flusher); ok {
358 fl.Flush()
359 }
360 }
361 // Same number of trailers as announced: emit them under their plain names.
362 if len(res.Trailer) == announcedTrailers {
363 copyHeader(rw.Header(), res.Trailer)
364 return
365 }
366 // Otherwise emit every trailer under the http.TrailerPrefix-marked key.
367 for k, vv := range res.Trailer {
368 k = http.TrailerPrefix + k
369 for _, v := range vv {
370 rw.Header().Add(k, v)
371 }
372 }
373 }
374
// inOurTests, when true, makes shouldPanicOnCopyError always report true.
// Nothing in the visible source assigns it; presumably the package's own
// tests set it — TODO confirm.
375 var inOurTests bool
376
377
378
379
380
381
382 func shouldPanicOnCopyError(req *http.Request) bool {
383 if inOurTests {
384
385 return true
386 }
387 if req.Context().Value(http.ServerContextKey) != nil {
388
389
390 return true
391 }
392
393
394 return false
395 }
396
397
398
// removeConnectionHeaders deletes from h every header named in a
// comma-separated "Connection" header value. Empty names are skipped and the
// Connection header itself is left in place.
func removeConnectionHeaders(h http.Header) {
	for _, line := range h["Connection"] {
		for _, token := range strings.Split(line, ",") {
			name := textproto.TrimString(token)
			if name == "" {
				continue
			}
			h.Del(name)
		}
	}
}
408
409
410
411 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
412 resCT := res.Header.Get("Content-Type")
413
414
415
416 if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
417 return -1
418 }
419
420
421 if res.ContentLength == -1 {
422 return -1
423 }
424
425 return p.FlushInterval
426 }
427
428 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
429 if flushInterval != 0 {
430 if wf, ok := dst.(writeFlusher); ok {
431 mlw := &maxLatencyWriter{
432 dst: wf,
433 latency: flushInterval,
434 }
435 defer mlw.stop()
436
437
438 mlw.flushPending = true
439 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
440
441 dst = mlw
442 }
443 }
444
445 var buf []byte
446 if p.BufferPool != nil {
447 buf = p.BufferPool.Get()
448 defer p.BufferPool.Put(buf)
449 }
450 _, err := p.copyBuffer(dst, src, buf)
451 return err
452 }
453
454
455
456 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
457 if len(buf) == 0 {
458 buf = make([]byte, 32*1024)
459 }
460 var written int64
461 for {
462 nr, rerr := src.Read(buf)
463 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
464 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
465 }
466 if nr > 0 {
467 nw, werr := dst.Write(buf[:nr])
468 if nw > 0 {
469 written += int64(nw)
470 }
471 if werr != nil {
472 return written, werr
473 }
474 if nr != nw {
475 return written, io.ErrShortWrite
476 }
477 }
478 if rerr != nil {
479 if rerr == io.EOF {
480 rerr = nil
481 }
482 return written, rerr
483 }
484 }
485 }
486
487 func (p *ReverseProxy) logf(format string, args ...any) {
488 if p.ErrorLog != nil {
489 p.ErrorLog.Printf(format, args...)
490 } else {
491 log.Printf(format, args...)
492 }
493 }
494
// writeFlusher is a destination that can be written to and flushed; it is
// what maxLatencyWriter wraps.
type writeFlusher interface {
	http.Flusher
	io.Writer
}
499
// maxLatencyWriter wraps a writeFlusher and flushes it either after every
// write (negative latency) or at most latency after a write, using a timer.
500 type maxLatencyWriter struct {
501 dst writeFlusher // destination that receives the writes and flushes
502 latency time.Duration // delay before a pending flush runs; negative means flush on every write
503
504 mu sync.Mutex // held by Write, delayedFlush and stop while they touch t and flushPending
505 t *time.Timer // timer that runs delayedFlush; nil until first armed
506 flushPending bool // true while a timer-driven flush is outstanding
507 }
508
509 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
510 m.mu.Lock()
511 defer m.mu.Unlock()
512 n, err = m.dst.Write(p)
513 if m.latency < 0 {
514 m.dst.Flush()
515 return
516 }
517 if m.flushPending {
518 return
519 }
520 if m.t == nil {
521 m.t = time.AfterFunc(m.latency, m.delayedFlush)
522 } else {
523 m.t.Reset(m.latency)
524 }
525 m.flushPending = true
526 return
527 }
528
529 func (m *maxLatencyWriter) delayedFlush() {
530 m.mu.Lock()
531 defer m.mu.Unlock()
532 if !m.flushPending {
533 return
534 }
535 m.dst.Flush()
536 m.flushPending = false
537 }
538
539 func (m *maxLatencyWriter) stop() {
540 m.mu.Lock()
541 defer m.mu.Unlock()
542 m.flushPending = false
543 if m.t != nil {
544 m.t.Stop()
545 }
546 }
547
548 func upgradeType(h http.Header) string {
549 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
550 return ""
551 }
552 return h.Get("Upgrade")
553 }
554
555 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
556 reqUpType := upgradeType(req.Header)
557 resUpType := upgradeType(res.Header)
558 if !ascii.IsPrint(resUpType) {
559 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
560 }
561 if !ascii.EqualFold(reqUpType, resUpType) {
562 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
563 return
564 }
565
566 hj, ok := rw.(http.Hijacker)
567 if !ok {
568 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
569 return
570 }
571 backConn, ok := res.Body.(io.ReadWriteCloser)
572 if !ok {
573 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
574 return
575 }
576
577 backConnCloseCh := make(chan bool)
578 go func() {
579
580
581 select {
582 case <-req.Context().Done():
583 case <-backConnCloseCh:
584 }
585 backConn.Close()
586 }()
587
588 defer close(backConnCloseCh)
589
590 conn, brw, err := hj.Hijack()
591 if err != nil {
592 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
593 return
594 }
595 defer conn.Close()
596
597 copyHeader(rw.Header(), res.Header)
598
599 res.Header = rw.Header()
600 res.Body = nil
601 if err := res.Write(brw); err != nil {
602 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
603 return
604 }
605 if err := brw.Flush(); err != nil {
606 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
607 return
608 }
609 errc := make(chan error, 1)
610 spc := switchProtocolCopier{user: conn, backend: backConn}
611 go spc.copyToBackend(errc)
612 go spc.copyFromBackend(errc)
613 <-errc
614 return
615 }
616
617
618
// switchProtocolCopier shuttles bytes between the client connection and the
// backend connection once a protocol switch has been agreed.
type switchProtocolCopier struct {
	user    io.ReadWriter // client side
	backend io.ReadWriter // backend side
}

// copyFromBackend copies everything readable from the backend to the user and
// then sends the copy's error (nil on a clean end) on errc.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	var err error
	_, err = io.Copy(c.user, c.backend)
	errc <- err
}

// copyToBackend copies everything readable from the user to the backend and
// then sends the copy's error (nil on a clean end) on errc.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	var err error
	_, err = io.Copy(c.backend, c.user)
	errc <- err
}
632
View as plain text