1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "io"
14 "log"
15 "mime"
16 "net"
17 "net/http"
18 "net/http/httptrace"
19 "net/http/internal/ascii"
20 "net/textproto"
21 "net/url"
22 "strings"
23 "sync"
24 "time"
25
26 "golang.org/x/net/http/httpguts"
27 )
28
29
// A ProxyRequest pairs the request received by the proxy with the request
// that will be sent to the backend. ServeHTTP builds one and hands it to
// the Rewrite hook.
type ProxyRequest struct {
	// In is the request received by the proxy. ServeHTTP sets it to the
	// original inbound request.
	In *http.Request

	// Out is the request that will be sent by the proxy. ServeHTTP sets it
	// to a clone of In from which hop-by-hop headers and the Forwarded /
	// X-Forwarded-* headers have already been removed. After Rewrite
	// returns, ServeHTTP forwards whatever Out then points to.
	Out *http.Request
}
41
42
43
44
45
46
47
48
49
50
51
52
53
54 func (r *ProxyRequest) SetURL(target *url.URL) {
55 rewriteRequestURL(r.Out, target)
56 r.Out.Host = ""
57 }
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78 func (r *ProxyRequest) SetXForwarded() {
79 clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
80 if err == nil {
81 prior := r.Out.Header["X-Forwarded-For"]
82 if len(prior) > 0 {
83 clientIP = strings.Join(prior, ", ") + ", " + clientIP
84 }
85 r.Out.Header.Set("X-Forwarded-For", clientIP)
86 } else {
87 r.Out.Header.Del("X-Forwarded-For")
88 }
89 r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
90 if r.In.TLS == nil {
91 r.Out.Header.Set("X-Forwarded-Proto", "http")
92 } else {
93 r.Out.Header.Set("X-Forwarded-Proto", "https")
94 }
95 }
96
97
98
99
100
101
102
// ReverseProxy is an HTTP handler that takes an incoming request, sends it
// to another server through Transport, and copies the response back to the
// client.
//
// Exactly one of Rewrite or Director must be set; ServeHTTP reports an
// error through the error handler otherwise.
type ReverseProxy struct {
	// Rewrite, if set, is called with a ProxyRequest describing the inbound
	// request (In) and the outbound request (Out) and may modify Out.
	//
	// Before Rewrite is called, ServeHTTP removes hop-by-hop headers and
	// the Forwarded, X-Forwarded-For, X-Forwarded-Host and
	// X-Forwarded-Proto headers from Out, and passes the outbound query
	// string through cleanQueryParams.
	//
	// At most one of Rewrite or Director may be set.
	Rewrite func(*ProxyRequest)

	// Director, if set, is called with the cloned outbound request and may
	// modify it in place.
	//
	// Unlike Rewrite, Director runs before hop-by-hop headers are removed
	// from the outbound request, so those removals also apply to headers
	// it adds. After Director returns, ServeHTTP appends the client IP to
	// X-Forwarded-For unless the header is present in the map with a nil
	// value, and re-encodes the query string only if the outbound
	// request's Form field is non-nil.
	//
	// At most one of Rewrite or Director may be set.
	Director func(*http.Request)

	// Transport performs the outbound request.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval is the interval at which buffered response data is
	// flushed to the client while the body is being copied.
	// Zero means no periodic flushing; a negative value flushes after
	// every write.
	//
	// It is ignored (and treated as negative) when the response has
	// Content-Type text/event-stream or a ContentLength of -1.
	FlushInterval time.Duration

	// ErrorLog is the logger used by the proxy's logf helper.
	// If nil, messages go to the standard log package's default logger.
	ErrorLog *log.Logger

	// BufferPool optionally supplies the byte slices used when copying
	// response bodies. If nil, or if it returns an empty slice, a 32 KiB
	// buffer is allocated per copy.
	BufferPool BufferPool

	// ModifyResponse, if set, is called with the backend response before
	// it is written to the client, including 101 Switching Protocols
	// responses. It is not called when the round trip itself fails.
	//
	// If it returns an error, the response body is closed and the error
	// handler is invoked with that error.
	ModifyResponse func(*http.Response) error

	// ErrorHandler, if set, handles errors from the round trip, from
	// ModifyResponse, and from proxy configuration or protocol-upgrade
	// failures. If nil, the default handler logs the error and replies
	// with 502 Bad Gateway.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
201
202
203
// A BufferPool supplies and recycles the temporary byte slices used by
// copyResponse while streaming a response body.
type BufferPool interface {
	// Get returns a buffer to copy with. A zero-length result makes
	// copyBuffer allocate its own 32 KiB buffer.
	Get() []byte
	// Put is handed back the slice obtained from Get once the copy is done.
	Put([]byte)
}
208
// singleJoiningSlash concatenates a and b so that exactly one "/" sits at
// the seam: a doubled slash is collapsed and a missing one is inserted.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")

	if trailing && leading {
		// Drop b's leading slash so the seam is not "//".
		return a + b[1:]
	}
	if trailing || leading {
		// Exactly one side already provides the slash.
		return a + b
	}
	return a + "/" + b
}
220
221 func joinURLPath(a, b *url.URL) (path, rawpath string) {
222 if a.RawPath == "" && b.RawPath == "" {
223 return singleJoiningSlash(a.Path, b.Path), ""
224 }
225
226
227 apath := a.EscapedPath()
228 bpath := b.EscapedPath()
229
230 aslash := strings.HasSuffix(apath, "/")
231 bslash := strings.HasPrefix(bpath, "/")
232
233 switch {
234 case aslash && bslash:
235 return a.Path + b.Path[1:], apath + bpath[1:]
236 case !aslash && !bslash:
237 return a.Path + "/" + b.Path, apath + "/" + bpath
238 }
239 return a.Path + b.Path, apath + bpath
240 }
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
263 director := func(req *http.Request) {
264 rewriteRequestURL(req, target)
265 }
266 return &ReverseProxy{Director: director}
267 }
268
269 func rewriteRequestURL(req *http.Request, target *url.URL) {
270 targetQuery := target.RawQuery
271 req.URL.Scheme = target.Scheme
272 req.URL.Host = target.Host
273 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
274 if targetQuery == "" || req.URL.RawQuery == "" {
275 req.URL.RawQuery = targetQuery + req.URL.RawQuery
276 } else {
277 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
278 }
279 }
280
// copyHeader appends every value of every header in src to dst. Existing
// values in dst are kept, and keys are canonicalized by Header.Add.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
288
289
290
291
292
293
// hopHeaders lists the hop-by-hop headers, in canonical form, that
// removeHopByHopHeaders always deletes. ServeHTTP applies it to both the
// outbound request and the backend response, so these headers are not
// forwarded across the proxy.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te", // canonical spelling of "TE"
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
305
306 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
307 p.logf("http: proxy error: %v", err)
308 rw.WriteHeader(http.StatusBadGateway)
309 }
310
311 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
312 if p.ErrorHandler != nil {
313 return p.ErrorHandler
314 }
315 return p.defaultErrorHandler
316 }
317
318
319
320 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
321 if p.ModifyResponse == nil {
322 return true
323 }
324 if err := p.ModifyResponse(res); err != nil {
325 res.Body.Close()
326 p.getErrorHandler()(rw, req, err)
327 return false
328 }
329 return true
330 }
331
// ServeHTTP proxies req to the backend selected by Director or Rewrite and
// streams the backend's response (status, headers, body and trailers) to rw.
//
// Errors that occur before the response header is written are reported
// through the configured error handler. An error while copying the body
// aborts the handler with http.ErrAbortHandler, unless
// shouldPanicOnCopyError says the panic should be suppressed.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() != nil {
		// The request context already has a Done channel, so it can be
		// canceled on its own; no extra cancellation plumbing is added.
	} else if cn, ok := rw.(http.CloseNotifier); ok {
		// Fallback for contexts that can never be canceled: derive a
		// cancelable context and cancel it when the ResponseWriter
		// signals that the client connection went away.
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			// The goroutine ends either on client disconnect or when
			// the deferred cancel fires at handler return.
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
			}
		}()
	}

	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// A request known to have no body is sent with a nil Body.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// Make sure the body is closed when the handler returns, whether
		// or not the transport got to it.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		// Director and Rewrite may write to the header map, so it must
		// never be nil.
		outreq.Header = make(http.Header)
	}

	// Exactly one of Director and Rewrite must be configured.
	if (p.Director != nil) == (p.Rewrite != nil) {
		p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
		return
	}

	if p.Director != nil {
		p.Director(outreq)
		if outreq.Form != nil {
			// The query is only cleaned when the request's Form field
			// has been populated.
			outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
		}
	}
	outreq.Close = false

	// Capture the requested upgrade protocol before the hop-by-hop headers
	// (which include Connection and Upgrade) are stripped.
	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeHopByHopHeaders(outreq.Header)

	// "Te" was just removed as a hop-by-hop header. Re-add it with the
	// single value "trailers" when the inbound request announced support
	// for trailers; any other Te values stay dropped.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Likewise, restore the Connection/Upgrade pair for protocol upgrade
	// requests after the hop-by-hop removal above.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if p.Rewrite != nil {
		// Strip forwarding headers supplied by the client; Rewrite can
		// add its own (for example via ProxyRequest.SetXForwarded).
		outreq.Header.Del("Forwarded")
		outreq.Header.Del("X-Forwarded-For")
		outreq.Header.Del("X-Forwarded-Host")
		outreq.Header.Del("X-Forwarded-Proto")

		// With Rewrite the query string is always cleaned.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)

		pr := &ProxyRequest{
			In:  req,
			Out: outreq,
		}
		p.Rewrite(pr)
		// Rewrite may have replaced Out with a different request.
		outreq = pr.Out
	} else {
		if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			// Append the client IP to any existing X-Forwarded-For
			// values. A key that is present with a nil value means the
			// Director asked for the header to be left out entirely.
			prior, ok := outreq.Header["X-Forwarded-For"]
			omit := ok && prior == nil
			if len(prior) > 0 {
				clientIP = strings.Join(prior, ", ") + ", " + clientIP
			}
			if !omit {
				outreq.Header.Set("X-Forwarded-For", clientIP)
			}
		}
	}

	if _, ok := outreq.Header["User-Agent"]; !ok {
		// No User-Agent on the outbound request: set it explicitly to
		// the empty string rather than leaving the key absent.
		outreq.Header.Set("User-Agent", "")
	}

	// Relay informational (1xx) responses from the backend to the client
	// as they arrive.
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			h := rw.Header()
			copyHeader(h, http.Header(header))
			rw.WriteHeader(code)

			// Empty the header map again so the 1xx headers do not leak
			// into the final response.
			for k := range h {
				delete(h, k)
			}

			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

	res, err := transport.RoundTrip(outreq)
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// 101 Switching Protocols responses are handed to the upgrade path,
	// which takes over the connection.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeHopByHopHeaders(res.Header)

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// Announce the trailer keys known at this point in a Trailer header.
	// The count is remembered to detect trailers that only show up after
	// the body has been read.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The header has already been written, so the failure cannot be
		// reported with a status code; abort the handler instead, unless
		// shouldPanicOnCopyError decides the panic must be suppressed.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	// Close immediately rather than deferring: res.Trailer is read below.
	res.Body.Close()

	if len(res.Trailer) > 0 {
		// Flush before the trailer values are added to the header map
		// below. The Flush error is ignored.
		http.NewResponseController(rw).Flush()
	}

	if len(res.Trailer) == announcedTrailers {
		// Every trailer was announced up front; plain header keys suffice.
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// The number of trailers differs from what was announced, so send them
	// using the http.TrailerPrefix key convention instead.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
542
// inOurTests forces shouldPanicOnCopyError to report true when set. Nothing
// in this file assigns it; presumably the package's own tests do.
var inOurTests bool

// shouldPanicOnCopyError reports whether ServeHTTP should panic with
// http.ErrAbortHandler after a failed body copy.
//
// It reports true when inOurTests is set or when the request context
// carries an http.ServerContextKey value, i.e. the request is being served
// by an http.Server. In any other case it reports false, and the caller
// logs the error instead of panicking.
func shouldPanicOnCopyError(req *http.Request) bool {
	servedByHTTPServer := func() bool {
		return req.Context().Value(http.ServerContextKey) != nil
	}
	return inOurTests || servedByHTTPServer()
}
564
565
566 func removeHopByHopHeaders(h http.Header) {
567
568 for _, f := range h["Connection"] {
569 for _, sf := range strings.Split(f, ",") {
570 if sf = textproto.TrimString(sf); sf != "" {
571 h.Del(sf)
572 }
573 }
574 }
575
576
577
578 for _, f := range hopHeaders {
579 h.Del(f)
580 }
581 }
582
583
584
585 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
586 resCT := res.Header.Get("Content-Type")
587
588
589
590 if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
591 return -1
592 }
593
594
595 if res.ContentLength == -1 {
596 return -1
597 }
598
599 return p.FlushInterval
600 }
601
602 func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
603 var w io.Writer = dst
604
605 if flushInterval != 0 {
606 mlw := &maxLatencyWriter{
607 dst: dst,
608 flush: http.NewResponseController(dst).Flush,
609 latency: flushInterval,
610 }
611 defer mlw.stop()
612
613
614 mlw.flushPending = true
615 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
616
617 w = mlw
618 }
619
620 var buf []byte
621 if p.BufferPool != nil {
622 buf = p.BufferPool.Get()
623 defer p.BufferPool.Put(buf)
624 }
625 _, err := p.copyBuffer(w, src, buf)
626 return err
627 }
628
629
630
631 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
632 if len(buf) == 0 {
633 buf = make([]byte, 32*1024)
634 }
635 var written int64
636 for {
637 nr, rerr := src.Read(buf)
638 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
639 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
640 }
641 if nr > 0 {
642 nw, werr := dst.Write(buf[:nr])
643 if nw > 0 {
644 written += int64(nw)
645 }
646 if werr != nil {
647 return written, werr
648 }
649 if nr != nw {
650 return written, io.ErrShortWrite
651 }
652 }
653 if rerr != nil {
654 if rerr == io.EOF {
655 rerr = nil
656 }
657 return written, rerr
658 }
659 }
660 }
661
662 func (p *ReverseProxy) logf(format string, args ...any) {
663 if p.ErrorLog != nil {
664 p.ErrorLog.Printf(format, args...)
665 } else {
666 log.Printf(format, args...)
667 }
668 }
669
// maxLatencyWriter wraps a writer and flushes it no later than latency
// after a write, or immediately after every write when latency is
// negative.
type maxLatencyWriter struct {
	dst     io.Writer     // destination of all writes
	flush   func() error  // flushes dst; its error is ignored by callers here
	latency time.Duration // maximum delay before a flush; negative means flush on every write

	mu           sync.Mutex  // held by Write, delayedFlush and stop; protects t and flushPending
	t            *time.Timer // timer that triggers delayedFlush; nil until first armed
	flushPending bool        // true while a timer-driven flush is scheduled
}
679
680 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
681 m.mu.Lock()
682 defer m.mu.Unlock()
683 n, err = m.dst.Write(p)
684 if m.latency < 0 {
685 m.flush()
686 return
687 }
688 if m.flushPending {
689 return
690 }
691 if m.t == nil {
692 m.t = time.AfterFunc(m.latency, m.delayedFlush)
693 } else {
694 m.t.Reset(m.latency)
695 }
696 m.flushPending = true
697 return
698 }
699
700 func (m *maxLatencyWriter) delayedFlush() {
701 m.mu.Lock()
702 defer m.mu.Unlock()
703 if !m.flushPending {
704 return
705 }
706 m.flush()
707 m.flushPending = false
708 }
709
710 func (m *maxLatencyWriter) stop() {
711 m.mu.Lock()
712 defer m.mu.Unlock()
713 m.flushPending = false
714 if m.t != nil {
715 m.t.Stop()
716 }
717 }
718
719 func upgradeType(h http.Header) string {
720 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
721 return ""
722 }
723 return h.Get("Upgrade")
724 }
725
726 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
727 reqUpType := upgradeType(req.Header)
728 resUpType := upgradeType(res.Header)
729 if !ascii.IsPrint(resUpType) {
730 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
731 }
732 if !ascii.EqualFold(reqUpType, resUpType) {
733 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
734 return
735 }
736
737 backConn, ok := res.Body.(io.ReadWriteCloser)
738 if !ok {
739 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
740 return
741 }
742
743 rc := http.NewResponseController(rw)
744 conn, brw, hijackErr := rc.Hijack()
745 if errors.Is(hijackErr, http.ErrNotSupported) {
746 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
747 return
748 }
749
750 backConnCloseCh := make(chan bool)
751 go func() {
752
753
754 select {
755 case <-req.Context().Done():
756 case <-backConnCloseCh:
757 }
758 backConn.Close()
759 }()
760 defer close(backConnCloseCh)
761
762 if hijackErr != nil {
763 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
764 return
765 }
766 defer conn.Close()
767
768 copyHeader(rw.Header(), res.Header)
769
770 res.Header = rw.Header()
771 res.Body = nil
772 if err := res.Write(brw); err != nil {
773 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
774 return
775 }
776 if err := brw.Flush(); err != nil {
777 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
778 return
779 }
780 errc := make(chan error, 1)
781 spc := switchProtocolCopier{user: conn, backend: backConn}
782 go spc.copyToBackend(errc)
783 go spc.copyFromBackend(errc)
784 <-errc
785 }
786
787
788
// switchProtocolCopier holds the two ends of an upgraded connection and
// copies data between them in both directions.
type switchProtocolCopier struct {
	// user is the hijacked client connection; backend is the connection
	// to the upstream server.
	user, backend io.ReadWriter
}
792
793 func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
794 _, err := io.Copy(c.user, c.backend)
795 errc <- err
796 }
797
798 func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
799 _, err := io.Copy(c.backend, c.user)
800 errc <- err
801 }
802
803 func cleanQueryParams(s string) string {
804 reencode := func(s string) string {
805 v, _ := url.ParseQuery(s)
806 return v.Encode()
807 }
808 for i := 0; i < len(s); {
809 switch s[i] {
810 case ';':
811 return reencode(s)
812 case '%':
813 if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
814 return reencode(s)
815 }
816 i += 3
817 default:
818 i++
819 }
820 }
821 return s
822 }
823
// ishex reports whether c is an ASCII hexadecimal digit (0-9, a-f, A-F).
func ishex(c byte) bool {
	isDigit := '0' <= c && c <= '9'
	isLower := 'a' <= c && c <= 'f'
	isUpper := 'A' <= c && c <= 'F'
	return isDigit || isLower || isUpper
}
835
View as plain text