1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "io"
14 "log"
15 "mime"
16 "net"
17 "net/http"
18 "net/http/httptrace"
19 "net/http/internal/ascii"
20 "net/textproto"
21 "net/url"
22 "strings"
23 "sync"
24 "time"
25
26 "golang.org/x/net/http/httpguts"
27 )
28
29
// ProxyRequest is the value ServeHTTP passes to ReverseProxy.Rewrite. It
// pairs the request the proxy received with the request it will send.
type ProxyRequest struct {
	// In is the inbound request received by the proxy.
	// NOTE(review): Rewrite implementations presumably should treat it as
	// read-only — confirm.
	In *http.Request

	// Out is the outbound request. ServeHTTP creates it by cloning In, and
	// after Rewrite returns it sends whatever Out then points to, so Rewrite
	// may modify it in place or replace it.
	Out *http.Request
}
41
42
43
44
45
46
47
48
49
50
51
52
53
// SetURL routes the outbound request to target. The scheme and host are
// taken from target, target's path is joined in front of the outbound path,
// and target's query is placed before the outbound query.
//
// It also clears Out.Host. Under net/http client semantics, an empty
// Request.Host means the Host header is taken from the (new) URL host.
func (r *ProxyRequest) SetURL(target *url.URL) {
	out := r.Out
	rewriteRequestURL(out, target)
	out.Host = ""
}
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host and
// X-Forwarded-Proto headers of the outbound request.
//
// X-Forwarded-For gets the client IP from In.RemoteAddr. Any values already
// present on Out are kept in front of it, folded into one comma-separated
// value. If RemoteAddr is not a host:port pair, the header is removed.
// X-Forwarded-Host is In.Host. X-Forwarded-Proto is "https" when In arrived
// over TLS and "http" otherwise.
func (r *ProxyRequest) SetXForwarded() {
	hdr := r.Out.Header

	if ip, _, splitErr := net.SplitHostPort(r.In.RemoteAddr); splitErr != nil {
		hdr.Del("X-Forwarded-For")
	} else {
		if earlier := hdr["X-Forwarded-For"]; len(earlier) > 0 {
			ip = strings.Join(earlier, ", ") + ", " + ip
		}
		hdr.Set("X-Forwarded-For", ip)
	}

	hdr.Set("X-Forwarded-Host", r.In.Host)

	scheme := "https"
	if r.In.TLS == nil {
		scheme = "http"
	}
	hdr.Set("X-Forwarded-Proto", scheme)
}
96
97
98
99
100
101
102
// ReverseProxy is an HTTP handler (see its ServeHTTP method). It forwards
// each incoming request to a backend and copies the backend's response back
// to the client.
//
// Exactly one of Rewrite or Director must be set. If neither or both are
// set, ServeHTTP reports an error through the error handler.
type ReverseProxy struct {
	// Rewrite, if set, customizes the outbound request. It receives a
	// ProxyRequest whose In is the inbound request and whose Out is the
	// outbound clone; ServeHTTP sends pr.Out afterwards.
	//
	// Before Rewrite runs, ServeHTTP has already:
	//   - removed hop-by-hop headers from Out (re-adding only "Te: trailers"
	//     and upgrade headers where applicable);
	//   - removed Forwarded, X-Forwarded-For, X-Forwarded-Host and
	//     X-Forwarded-Proto from Out;
	//   - passed Out's raw query through cleanQueryParams.
	//
	// No X-Forwarded-* headers are added on this path.
	// ProxyRequest.SetXForwarded can be called to add them.
	Rewrite func(*ProxyRequest)

	// Director, if set, modifies the outbound request (a clone of the
	// inbound one) in place. It runs before hop-by-hop headers are removed.
	// If it populated the request's Form, the raw query is then passed
	// through cleanQueryParams.
	//
	// Afterwards, when the inbound RemoteAddr is a host:port pair, the
	// client IP is appended to X-Forwarded-For, unless the Director set that
	// header key to a nil value.
	Director func(*http.Request)

	// Transport performs the outbound round trip. If nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval controls flushing to the client while the response body
	// is copied, when the ResponseWriter supports flushing:
	//   - zero disables periodic flushing;
	//   - a negative value flushes after every write;
	//   - a positive value flushes at most that long after a write.
	// It is overridden with -1 (flush after every write) for
	// text/event-stream responses and for responses whose ContentLength is -1.
	FlushInterval time.Duration

	// ErrorLog receives the proxy's log messages. If nil, messages go to the
	// log package's standard logger.
	ErrorLog *log.Logger

	// BufferPool, if non-nil, supplies the byte slice used to copy each
	// response body; the slice is handed back to Put after the copy. If Get
	// returns an empty slice, a temporary 32 KiB buffer is used instead.
	BufferPool BufferPool

	// ModifyResponse, if non-nil, is called with the backend's response
	// before anything is written to the client, including for
	// 101 Switching Protocols responses. For other responses, hop-by-hop
	// headers have already been removed.
	//
	// If it returns an error, the response body is closed and the error
	// handler is called instead of forwarding the response.
	ModifyResponse func(*http.Response) error

	// ErrorHandler, if non-nil, is called for proxy errors such as a failed
	// round trip, an error from ModifyResponse, or a failed protocol
	// upgrade. If nil, a default handler logs the error and replies with
	// 502 Bad Gateway.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
201
202
203
// BufferPool supplies and recycles the byte slices that ReverseProxy uses to
// copy response bodies. Get is called once per copied response body, and the
// same slice is passed to Put when that copy finishes.
//
// NOTE(review): ServeHTTP may run concurrently for many requests, so
// implementations presumably need to be safe for concurrent use — confirm.
type BufferPool interface {
	Get() []byte
	Put([]byte)
}
208
// singleJoiningSlash concatenates a and b with exactly one "/" at the seam.
// A duplicate slash (a ends with "/" and b starts with "/") is collapsed,
// and a missing one is inserted. Slashes elsewhere are left untouched.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")

	if trailing && leading {
		return a + strings.TrimPrefix(b, "/")
	}
	if !trailing && !leading {
		return a + "/" + b
	}
	return a + b
}
220
// joinURLPath joins the paths of a and b with a single slash at the seam. It
// returns the joined decoded path together with the joined escaped (raw)
// path.
//
// When neither URL has a RawPath, only the decoded paths are joined and
// rawpath is "". Otherwise the slash decision is made on the escaped forms,
// and the decoded and escaped results are built in lockstep.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return singleJoiningSlash(a.Path, b.Path), ""
	}

	escA := a.EscapedPath()
	escB := b.EscapedPath()
	endsSlash := strings.HasSuffix(escA, "/")
	startsSlash := strings.HasPrefix(escB, "/")

	if endsSlash && startsSlash {
		// Drop b's leading slash from both representations.
		return a.Path + b.Path[1:], escA + escB[1:]
	}
	if endsSlash || startsSlash {
		// Exactly one side already supplies the slash.
		return a.Path + b.Path, escA + escB
	}
	return a.Path + "/" + b.Path, escA + "/" + escB
}
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
// NewSingleHostReverseProxy returns a ReverseProxy whose Director sends every
// request to target. The outbound URL takes target's scheme and host, has
// target's path joined in front of the request path, and has target's query
// placed before the request query.
//
// Only the URL is rewritten: the outbound request's Host field is left as
// it was on the inbound request.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
	return &ReverseProxy{
		Director: func(req *http.Request) {
			rewriteRequestURL(req, target)
		},
	}
}
268
// rewriteRequestURL points req.URL at target in place:
//   - scheme and host are copied from target;
//   - target's path is joined in front of the request path;
//   - target's raw query is placed before the request's raw query, joined
//     with "&" when both are non-empty.
func rewriteRequestURL(req *http.Request, target *url.URL) {
	// Read target's query up front, before req.URL is modified.
	tq := target.RawQuery

	u := req.URL
	u.Scheme, u.Host = target.Scheme, target.Host
	u.Path, u.RawPath = joinURLPath(target, u)

	if tq != "" && u.RawQuery != "" {
		u.RawQuery = tq + "&" + u.RawQuery
		return
	}
	u.RawQuery = tq + u.RawQuery
}
280
// copyHeader appends every value in src to dst under the same key. Existing
// values in dst are kept. Keys are canonicalized by http.Header.Add.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
288
289
290
291
292
293
// hopHeaders lists the headers that removeHopByHopHeaders always deletes. It
// also deletes any header named in the Connection header. These headers
// describe a single connection (hop) and are not forwarded by the proxy.
//
// NOTE(review): the list appears to follow RFC 2616 section 13.5.1, plus the
// non-standard Proxy-Connection — confirm.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
305
306 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
307 p.logf("http: proxy error: %v", err)
308 rw.WriteHeader(http.StatusBadGateway)
309 }
310
// getErrorHandler returns p.ErrorHandler when it is set, and
// p.defaultErrorHandler otherwise.
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
	if h := p.ErrorHandler; h != nil {
		return h
	}
	return p.defaultErrorHandler
}
317
318
319
// modifyResponse runs p.ModifyResponse on res, if configured, and reports
// whether the caller should continue sending the response.
//
// On a hook error it closes res.Body, passes the error to the error handler
// (with req), and returns false. It returns true when there is no hook or the
// hook succeeded.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
	hook := p.ModifyResponse
	if hook == nil {
		return true
	}

	err := hook(res)
	if err == nil {
		return true
	}

	res.Body.Close()
	p.getErrorHandler()(rw, req, err)
	return false
}
331
// ServeHTTP proxies req to the backend selected by p.Rewrite or p.Director
// and streams the backend's response back through rw.
//
// Outline:
//   - choose p.Transport, falling back to http.DefaultTransport;
//   - clone req into outreq (with a cancellable context when needed) and let
//     exactly one of Director or Rewrite modify it;
//   - strip hop-by-hop headers, re-adding only "Te: trailers" and the
//     headers a protocol upgrade needs;
//   - round-trip outreq and relay any 1xx informational responses;
//   - hand 101 Switching Protocols responses to handleUpgradeResponse, and
//     otherwise copy status, headers, body and trailers to rw.
//
// Errors that occur before the response status is written go to the
// configured error handler (see getErrorHandler). A failure while streaming
// the body panics with http.ErrAbortHandler when shouldPanicOnCopyError
// reports true; otherwise it is only logged.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() != nil {
		// The request context already carries a cancellation signal, so
		// there is nothing to do here; a CloseNotifier is not consulted.
	} else if cn, ok := rw.(http.CloseNotifier); ok {
		// The context can never be cancelled (nil Done channel). Derive a
		// cancellable one and cancel it when the client connection closes.
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			// Exits on client close, or when the deferred cancel runs as
			// ServeHTTP returns.
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
			}
		}()
	}

	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// Send no body at all for an explicitly empty request body.
		// NOTE(review): presumably so http.Transport can retry the request —
		// confirm.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// NOTE(review): presumably closed here because the transport may
		// still be reading the body after this handler returns — confirm.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		// Director/Rewrite and the code below write into this map, so it
		// must not be nil.
		outreq.Header = make(http.Header)
	}

	// Exactly one of Director or Rewrite must be configured.
	if (p.Director != nil) == (p.Rewrite != nil) {
		p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
		return
	}

	if p.Director != nil {
		p.Director(outreq)
		// If the Director parsed the form, sanitize the raw query so it
		// does not carry parameters that fail to parse.
		if outreq.Form != nil {
			outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
		}
	}
	// Ignore the inbound request's connection-close flag for the outbound
	// request.
	// NOTE(review): presumably so backend connections can be reused — confirm.
	outreq.Close = false

	// Capture the requested upgrade protocol before the Connection/Upgrade
	// headers are stripped below.
	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeHopByHopHeaders(outreq.Header)

	// Tell the backend that trailers are supported if the client said so.
	// req.Header is consulted because outreq.Header has just lost its Te
	// header.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Re-add the headers a protocol upgrade needs; they were removed with
	// the other hop-by-hop headers.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if p.Rewrite != nil {
		// Drop client-supplied forwarding headers; Rewrite can set fresh
		// ones (for example via ProxyRequest.SetXForwarded).
		outreq.Header.Del("Forwarded")
		outreq.Header.Del("X-Forwarded-For")
		outreq.Header.Del("X-Forwarded-Host")
		outreq.Header.Del("X-Forwarded-Proto")

		// On the Rewrite path the query is always sanitized.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)

		pr := &ProxyRequest{
			In:  req,
			Out: outreq,
		}
		p.Rewrite(pr)
		// Rewrite may have replaced pr.Out, so use whatever it now holds.
		outreq = pr.Out
	} else {
		if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			// Append the client IP to any existing X-Forwarded-For values,
			// folding them into one comma+space separated header. A key
			// present with a nil value (set by the Director) means "do not
			// send X-Forwarded-For".
			prior, ok := outreq.Header["X-Forwarded-For"]
			omit := ok && prior == nil
			if len(prior) > 0 {
				clientIP = strings.Join(prior, ", ") + ", " + clientIP
			}
			if !omit {
				outreq.Header.Set("X-Forwarded-For", clientIP)
			}
		}
	}

	if _, ok := outreq.Header["User-Agent"]; !ok {
		// Set an explicit empty User-Agent.
		// NOTE(review): presumably so the Go HTTP client does not add its
		// default User-Agent — confirm.
		outreq.Header.Set("User-Agent", "")
	}

	// Relay each 1xx informational response from the backend to the client
	// as it arrives.
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			h := rw.Header()
			copyHeader(h, http.Header(header))
			rw.WriteHeader(code)

			// Clear the headers that were just sent so they do not carry
			// over into the final response.
			for k := range h {
				delete(h, k)
			}

			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

	res, err := transport.RoundTrip(outreq)
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// 101 Switching Protocols (e.g. WebSocket): hand the exchange to
	// handleUpgradeResponse. Hop-by-hop headers are kept here because the
	// upgrade needs them.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeHopByHopHeaders(res.Header)

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// Declare the trailer keys known so far in a "Trailer" header built
	// from res.Trailer.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The status line has already been written, so the only way left
		// to signal a truncated body is to abort the handler.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	// Close now rather than via defer: res.Trailer is read right below.
	// NOTE(review): relies on net/http having populated res.Trailer once the
	// body has been read to EOF — confirm.
	res.Body.Close()

	if len(res.Trailer) > 0 {
		// Flush so the response goes out chunked, the encoding that carries
		// trailers.
		// NOTE(review): presumably prevents net/http from computing a
		// Content-Length for short bodies — confirm.
		if fl, ok := rw.(http.Flusher); ok {
			fl.Flush()
		}
	}

	// No new trailer keys appeared, so every key was declared in the
	// "Trailer" header and can be set directly.
	if len(res.Trailer) == announcedTrailers {
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Undeclared trailers are sent using http.TrailerPrefix keys.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
544
// inOurTests, when true, makes shouldPanicOnCopyError report true
// unconditionally.
// NOTE(review): nothing in this file assigns it; presumably it is set only
// by this package's tests — confirm.
var inOurTests bool
546
547
548
549
550
551
// shouldPanicOnCopyError reports whether ServeHTTP should panic with
// http.ErrAbortHandler after a failed response-body copy.
//
// It returns true in this package's own tests (inOurTests), or when req's
// context carries http.ServerContextKey, meaning the handler was started by
// an http.Server. Otherwise it returns false, and the caller only logs the
// error.
func shouldPanicOnCopyError(req *http.Request) bool {
	if inOurTests {
		return true
	}
	servedByServer := req.Context().Value(http.ServerContextKey) != nil
	return servedByServer
}
566
567
// removeHopByHopHeaders deletes connection-specific headers from h in place.
//
// It first removes every header named in the comma-separated Connection
// header values (RFC 7230, section 6.1). It then removes the fixed list in
// hopHeaders, which includes Connection itself.
func removeHopByHopHeaders(h http.Header) {
	for _, line := range h["Connection"] {
		for _, name := range strings.Split(line, ",") {
			name = textproto.TrimString(name)
			if name == "" {
				continue
			}
			h.Del(name)
		}
	}

	for _, name := range hopHeaders {
		h.Del(name)
	}
}
584
585
586
// flushInterval returns the flush interval to use while copying res's body.
//
// It returns -1 (flush after every write) in two cases:
//   - Server-Sent Events (base media type "text/event-stream");
//   - a body of unknown length (ContentLength == -1), which may be a stream.
//
// Otherwise it returns p.FlushInterval.
func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
	mediaType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type"))
	if mediaType == "text/event-stream" || res.ContentLength == -1 {
		return -1
	}
	return p.FlushInterval
}
603
// copyResponse copies src to dst and returns the first copy error (nil on a
// clean EOF).
//
// When flushInterval is non-zero and dst can flush, dst is wrapped in a
// maxLatencyWriter. An initial flush is armed so headers reach the client
// even if the first body bytes are slow. When p.BufferPool is set, the copy
// buffer comes from the pool and is returned to it afterwards.
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
	if wf, canFlush := dst.(writeFlusher); canFlush && flushInterval != 0 {
		mlw := &maxLatencyWriter{dst: wf, latency: flushInterval}
		defer mlw.stop()

		// Arm the first flush up front.
		mlw.flushPending = true
		mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)

		dst = mlw
	}

	var scratch []byte
	if pool := p.BufferPool; pool != nil {
		scratch = pool.Get()
		defer pool.Put(scratch)
	}

	_, err := p.copyBuffer(dst, src, scratch)
	return err
}
629
630
631
// copyBuffer copies src to dst through buf (a fresh 32 KiB buffer when buf is
// empty). It returns the number of bytes written and the first error.
//
// Read errors other than io.EOF and context.Canceled are logged. Bytes
// returned together with a read error are still written before that error is
// returned. A short write reports io.ErrShortWrite, and io.EOF is reported as
// nil.
func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
	if len(buf) == 0 {
		buf = make([]byte, 32*1024)
	}

	var total int64
	for {
		n, readErr := src.Read(buf)
		if readErr != nil && readErr != io.EOF && readErr != context.Canceled {
			p.logf("httputil: ReverseProxy read error during body copy: %v", readErr)
		}

		if n > 0 {
			m, writeErr := dst.Write(buf[:n])
			if m > 0 {
				total += int64(m)
			}
			switch {
			case writeErr != nil:
				return total, writeErr
			case m != n:
				return total, io.ErrShortWrite
			}
		}

		if readErr == io.EOF {
			return total, nil
		}
		if readErr != nil {
			return total, readErr
		}
	}
}
662
// logf writes a formatted message to p.ErrorLog, or to the log package's
// standard logger when ErrorLog is nil.
func (p *ReverseProxy) logf(format string, args ...any) {
	if l := p.ErrorLog; l != nil {
		l.Printf(format, args...)
		return
	}
	log.Printf(format, args...)
}
670
// writeFlusher is an io.Writer that can also flush buffered data to its
// destination. copyResponse wraps such writers in a maxLatencyWriter when a
// flush interval is in effect.
type writeFlusher interface {
	io.Writer
	http.Flusher
}
675
// maxLatencyWriter wraps a writeFlusher so that written data is flushed
// promptly:
//   - when latency is negative, immediately after each Write;
//   - otherwise, by a timer that fires latency after a Write that found no
//     flush pending.
type maxLatencyWriter struct {
	dst     writeFlusher
	latency time.Duration // negative means flush after every Write

	// mu is held by Write, delayedFlush and stop while they use the fields
	// below and dst.
	mu sync.Mutex

	// t runs delayedFlush. It is nil until copyResponse arms it or the
	// first Write schedules a flush.
	t *time.Timer

	// flushPending reports whether a delayed flush is scheduled and should
	// still run.
	flushPending bool
}
684
// Write writes p to the wrapped writer under m.mu and then arranges a flush:
//   - negative latency: flush immediately;
//   - a flush already pending: do nothing;
//   - otherwise: (re)arm the timer to flush after m.latency.
//
// It returns whatever the underlying Write returned.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err = m.dst.Write(p)

	switch {
	case m.latency < 0:
		m.dst.Flush()
	case m.flushPending:
		// A flush is already scheduled; it will cover these bytes too.
	case m.t == nil:
		m.t = time.AfterFunc(m.latency, m.delayedFlush)
		m.flushPending = true
	default:
		m.t.Reset(m.latency)
		m.flushPending = true
	}
	return n, err
}
704
705 func (m *maxLatencyWriter) delayedFlush() {
706 m.mu.Lock()
707 defer m.mu.Unlock()
708 if !m.flushPending {
709 return
710 }
711 m.dst.Flush()
712 m.flushPending = false
713 }
714
715 func (m *maxLatencyWriter) stop() {
716 m.mu.Lock()
717 defer m.mu.Unlock()
718 m.flushPending = false
719 if m.t != nil {
720 m.t.Stop()
721 }
722 }
723
// upgradeType returns the Upgrade header value when the Connection header
// contains the "Upgrade" token (as judged by
// httpguts.HeaderValuesContainsToken). Otherwise it returns "".
func upgradeType(h http.Header) string {
	if httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
		return h.Get("Upgrade")
	}
	return ""
}
730
// handleUpgradeResponse completes a 101 Switching Protocols exchange. It
// hijacks the client connection and splices it to the backend connection
// carried in res.Body.
//
// The backend's upgrade protocol must be printable ASCII and must match the
// protocol the client requested (compared with ascii.EqualFold).
//
// Error handling:
//   - a failure before the close-on-cancel goroutine is started is reported
//     exactly once through the error handler, and res.Body is closed so the
//     backend connection is released;
//   - after that point, the goroutine owns closing the backend connection.
//
// Once both copy directions are running, the method returns as soon as
// either finishes, and the deferred closes tear down both connections.
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
	// fail reports err and releases the backend connection. It is used only
	// before ownership of the body passes to the goroutine below.
	fail := func(err error) {
		if res.Body != nil {
			res.Body.Close()
		}
		p.getErrorHandler()(rw, req, err)
	}

	reqUpType := upgradeType(req.Header)
	resUpType := upgradeType(res.Header)
	if !ascii.IsPrint(resUpType) {
		// Return here: otherwise the EqualFold check below would also fail,
		// and the error handler would be invoked (and the status written)
		// a second time.
		fail(fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
		return
	}
	if !ascii.EqualFold(reqUpType, resUpType) {
		fail(fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
		return
	}

	hj, ok := rw.(http.Hijacker)
	if !ok {
		fail(fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
		return
	}
	backConn, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		fail(fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
		return
	}

	// From here on, this goroutine closes backConn: when the request
	// context is done, or when this method returns.
	backConnCloseCh := make(chan bool)
	go func() {
		select {
		case <-req.Context().Done():
		case <-backConnCloseCh:
		}
		backConn.Close()
	}()

	defer close(backConnCloseCh)

	conn, brw, err := hj.Hijack()
	if err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
		return
	}
	defer conn.Close()

	copyHeader(rw.Header(), res.Header)

	// Write the 101 status line and headers onto the hijacked connection.
	// Body is cleared so res.Write emits only the header; the backend
	// stream is still reachable as backConn.
	res.Header = rw.Header()
	res.Body = nil
	if err := res.Write(brw); err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
		return
	}
	if err := brw.Flush(); err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
		return
	}

	// Splice both directions and return when the first one finishes. The
	// one-slot buffer leaves room for the other copier's result once this
	// method has received the first.
	// NOTE(review): that copier presumably unblocks when the deferred
	// closes run — confirm.
	errc := make(chan error, 1)
	spc := switchProtocolCopier{user: conn, backend: backConn}
	go spc.copyToBackend(errc)
	go spc.copyFromBackend(errc)
	<-errc
}
791
792
793
794 type switchProtocolCopier struct {
795 user, backend io.ReadWriter
796 }
797
798 func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
799 _, err := io.Copy(c.user, c.backend)
800 errc <- err
801 }
802
803 func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
804 _, err := io.Copy(c.backend, c.user)
805 errc <- err
806 }
807
// cleanQueryParams returns the raw query s unchanged if it is well formed,
// meaning it has no ';' and every '%' starts a valid two-hex-digit escape.
//
// Otherwise it rebuilds the query with url.ParseQuery and url.Values.Encode.
// Pairs that fail to parse are dropped, and the result is sorted by key.
func cleanQueryParams(s string) string {
	i := 0
	for i < len(s) {
		if s[i] == ';' {
			break
		}
		if s[i] == '%' {
			if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
				break
			}
			i += 3
			continue
		}
		i++
	}
	if i >= len(s) {
		// The scan reached the end: nothing suspicious was found.
		return s
	}

	// Parse errors are ignored on purpose: ParseQuery skips the offending
	// pairs and keeps the rest.
	v, _ := url.ParseQuery(s)
	return v.Encode()
}

// ishex reports whether c is an ASCII hexadecimal digit (0-9, a-f, A-F).
func ishex(c byte) bool {
	return ('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')
}
840
View as plain text