Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 . "net/http"
30 "net/http/httptest"
31 "net/http/httptrace"
32 "net/http/httputil"
33 "net/http/internal/testcert"
34 "net/textproto"
35 "net/url"
36 "os"
37 "reflect"
38 "runtime"
39 "strconv"
40 "strings"
41 "sync"
42 "sync/atomic"
43 "testing"
44 "testing/iotest"
45 "time"
46
47 "golang.org/x/net/http/httpguts"
48 )
49
50
51
52
53
54 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
55 if r.FormValue("close") == "true" {
56 w.Header().Set("Connection", "close")
57 }
58 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
59 w.Write([]byte(r.RemoteAddr))
60
61
62
63 if c, ok := ResponseWriterConnForTesting(w); ok {
64 fmt.Fprintf(w, ", %T %p", c, c)
65 }
66 })
67
68
69 type testCloseConn struct {
70 net.Conn
71 set *testConnSet
72 }
73
74 func (c *testCloseConn) Close() error {
75 c.set.remove(c)
76 return c.Conn.Close()
77 }
78
79
80
// testConnSet records every connection handed out by a makeTestDial dialer,
// together with whether that connection has since been closed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex        // guards closed and list
	closed map[net.Conn]bool // true once the connection's Close has run
	list   []net.Conn        // connections in the order they were inserted
}

// insert registers c as a newly dialed connection that is still open.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. c stays in list so that check can still
// report it by position.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}
100
101
102 func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
103 connSet := &testConnSet{
104 t: t,
105 closed: make(map[net.Conn]bool),
106 }
107 dial := func(n, addr string) (net.Conn, error) {
108 c, err := net.Dial(n, addr)
109 if err != nil {
110 return nil, err
111 }
112 tc := &testCloseConn{c, connSet}
113 connSet.insert(tc)
114 return tc, nil
115 }
116 return connSet, dial
117 }
118
119 func (tcs *testConnSet) check(t *testing.T) {
120 tcs.mu.Lock()
121 defer tcs.mu.Unlock()
122 for i := 4; i >= 0; i-- {
123 for i, c := range tcs.list {
124 if tcs.closed[c] {
125 continue
126 }
127 if i != 0 {
128
129
130 tcs.mu.Unlock()
131 time.Sleep(50 * time.Millisecond)
132 tcs.mu.Lock()
133 continue
134 }
135 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
136 }
137 }
138 }
139
140 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
141 func testReuseRequest(t *testing.T, mode testMode) {
142 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
143 w.Write([]byte("{}"))
144 })).ts
145
146 c := ts.Client()
147 req, _ := NewRequest("GET", ts.URL, nil)
148 res, err := c.Do(req)
149 if err != nil {
150 t.Fatal(err)
151 }
152 err = res.Body.Close()
153 if err != nil {
154 t.Fatal(err)
155 }
156
157 res, err = c.Do(req)
158 if err != nil {
159 t.Fatal(err)
160 }
161 err = res.Body.Close()
162 if err != nil {
163 t.Fatal(err)
164 }
165 }
166
167
168
169 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
170 func testTransportKeepAlives(t *testing.T, mode testMode) {
171 ts := newClientServerTest(t, mode, hostPortHandler).ts
172
173 c := ts.Client()
174 for _, disableKeepAlive := range []bool{false, true} {
175 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
176 fetch := func(n int) string {
177 res, err := c.Get(ts.URL)
178 if err != nil {
179 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
180 }
181 body, err := io.ReadAll(res.Body)
182 if err != nil {
183 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
184 }
185 return string(body)
186 }
187
188 body1 := fetch(1)
189 body2 := fetch(2)
190
191 bodiesDiffer := body1 != body2
192 if bodiesDiffer != disableKeepAlive {
193 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
194 disableKeepAlive, bodiesDiffer, body1, body2)
195 }
196 }
197 }
198
199 func TestTransportConnectionCloseOnResponse(t *testing.T) {
200 run(t, testTransportConnectionCloseOnResponse)
201 }
202 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
203 ts := newClientServerTest(t, mode, hostPortHandler).ts
204
205 connSet, testDial := makeTestDial(t)
206
207 c := ts.Client()
208 tr := c.Transport.(*Transport)
209 tr.Dial = testDial
210
211 for _, connectionClose := range []bool{false, true} {
212 fetch := func(n int) string {
213 req := new(Request)
214 var err error
215 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
216 if err != nil {
217 t.Fatalf("URL parse error: %v", err)
218 }
219 req.Method = "GET"
220 req.Proto = "HTTP/1.1"
221 req.ProtoMajor = 1
222 req.ProtoMinor = 1
223
224 res, err := c.Do(req)
225 if err != nil {
226 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
227 }
228 defer res.Body.Close()
229 body, err := io.ReadAll(res.Body)
230 if err != nil {
231 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
232 }
233 return string(body)
234 }
235
236 body1 := fetch(1)
237 body2 := fetch(2)
238 bodiesDiffer := body1 != body2
239 if bodiesDiffer != connectionClose {
240 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
241 connectionClose, bodiesDiffer, body1, body2)
242 }
243
244 tr.CloseIdleConnections()
245 }
246
247 connSet.check(t)
248 }
249
250
251
252
253
254
255
256 func TestTransportConnectionCloseOnRequest(t *testing.T) {
257 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
258 }
259 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
260 ts := newClientServerTest(t, mode, hostPortHandler).ts
261
262 connSet, testDial := makeTestDial(t)
263
264 c := ts.Client()
265 tr := c.Transport.(*Transport)
266 tr.Dial = testDial
267 for _, reqClose := range []bool{false, true} {
268 fetch := func(n int) string {
269 req := new(Request)
270 var err error
271 req.URL, err = url.Parse(ts.URL)
272 if err != nil {
273 t.Fatalf("URL parse error: %v", err)
274 }
275 req.Method = "GET"
276 req.Proto = "HTTP/1.1"
277 req.ProtoMajor = 1
278 req.ProtoMinor = 1
279 req.Close = reqClose
280
281 res, err := c.Do(req)
282 if err != nil {
283 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
284 }
285 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
286 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
287 reqClose, got, !reqClose)
288 }
289 body, err := io.ReadAll(res.Body)
290 if err != nil {
291 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
292 }
293 return string(body)
294 }
295
296 body1 := fetch(1)
297 body2 := fetch(2)
298
299 got := 1
300 if body1 != body2 {
301 got++
302 }
303 want := 1
304 if reqClose {
305 want = 2
306 }
307 if got != want {
308 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
309 reqClose, got, want, body1, body2)
310 }
311
312 tr.CloseIdleConnections()
313 }
314
315 connSet.check(t)
316 }
317
318
319
320
321 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
322 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
323 }
324 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
325 ts := newClientServerTest(t, mode, hostPortHandler).ts
326
327 c := ts.Client()
328 c.Transport.(*Transport).DisableKeepAlives = true
329
330 res, err := c.Get(ts.URL)
331 if err != nil {
332 t.Fatal(err)
333 }
334 res.Body.Close()
335 if res.Header.Get("X-Saw-Close") != "true" {
336 t.Errorf("handler didn't see Connection: close ")
337 }
338 }
339
340
341
342 func TestTransportRespectRequestWantsClose(t *testing.T) {
343 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
344 }
345 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
346 tests := []struct {
347 disableKeepAlives bool
348 close bool
349 }{
350 {disableKeepAlives: false, close: false},
351 {disableKeepAlives: false, close: true},
352 {disableKeepAlives: true, close: false},
353 {disableKeepAlives: true, close: true},
354 }
355
356 for _, tc := range tests {
357 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
358 func(t *testing.T) {
359 ts := newClientServerTest(t, mode, hostPortHandler).ts
360
361 c := ts.Client()
362 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
363 req, err := NewRequest("GET", ts.URL, nil)
364 if err != nil {
365 t.Fatal(err)
366 }
367 count := 0
368 trace := &httptrace.ClientTrace{
369 WroteHeaderField: func(key string, field []string) {
370 if key != "Connection" {
371 return
372 }
373 if httpguts.HeaderValuesContainsToken(field, "close") {
374 count += 1
375 }
376 },
377 }
378 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
379 req.Close = tc.close
380 res, err := c.Do(req)
381 if err != nil {
382 t.Fatal(err)
383 }
384 defer res.Body.Close()
385 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
386 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
387 }
388 })
389 }
390
391 }
392
393 func TestTransportIdleCacheKeys(t *testing.T) {
394 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
395 }
396 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
397 ts := newClientServerTest(t, mode, hostPortHandler).ts
398 c := ts.Client()
399 tr := c.Transport.(*Transport)
400
401 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
402 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
403 }
404
405 resp, err := c.Get(ts.URL)
406 if err != nil {
407 t.Error(err)
408 }
409 io.ReadAll(resp.Body)
410
411 keys := tr.IdleConnKeysForTesting()
412 if e, g := 1, len(keys); e != g {
413 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
414 }
415
416 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
417 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
418 }
419
420 tr.CloseIdleConnections()
421 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
422 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
423 }
424 }
425
426
427
428 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
429 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
430 const msg = "foobar"
431
432 var addrSeen map[string]int
433 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
434 addrSeen[r.RemoteAddr]++
435 if r.URL.Path == "/chunked/" {
436 w.WriteHeader(200)
437 w.(Flusher).Flush()
438 } else {
439 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
440 w.WriteHeader(200)
441 }
442 w.Write([]byte(msg))
443 })).ts
444
445 for pi, path := range []string{"/content-length/", "/chunked/"} {
446 wantLen := []int{len(msg), -1}[pi]
447 addrSeen = make(map[string]int)
448 for i := 0; i < 3; i++ {
449 res, err := ts.Client().Get(ts.URL + path)
450 if err != nil {
451 t.Errorf("Get %s: %v", path, err)
452 continue
453 }
454
455
456
457
458
459 defer res.Body.Close()
460
461 if res.ContentLength != int64(wantLen) {
462 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
463 }
464 got, err := io.ReadAll(res.Body)
465 if string(got) != msg || err != nil {
466 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
467 }
468 }
469 if len(addrSeen) != 1 {
470 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
471 }
472 }
473 }
474
// TestTransportMaxPerHostIdleConns checks that the Transport caps idle
// connections per host at MaxIdleConnsPerHost (2 here). Three requests are
// held open on the server and then released one at a time. The idle count is
// expected to go 1, 2, and then stay at the limit.
func TestTransportMaxPerHostIdleConns(t *testing.T) {
	run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
}
func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
	// stop is closed when the test returns. This releases any handler or
	// doReq goroutine still blocked on a channel operation.
	stop := make(chan struct{})
	defer close(stop)

	// resch supplies the body for the next blocked handler. gotReq is
	// signalled by each handler as soon as its request arrives.
	resch := make(chan string)
	gotReq := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReq <- true
		var msg string
		select {
		case <-stop:
			return
		case msg = <-resch:
		}
		_, err := w.Write([]byte(msg))
		if err != nil {
			t.Errorf("Write: %v", err)
			return
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)
	maxIdleConnsPerHost := 2
	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost

	// donech receives t.Failed() from each doReq when it finishes, unless the
	// test has already returned (stop closed).
	donech := make(chan bool)
	// doReq issues one GET and reads its body to the end.
	// NOTE(review): resp.Body is read but never closed.
	doReq := func() {
		defer func() {
			select {
			case <-stop:
				return
			case donech <- t.Failed():
			}
		}()
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Error(err)
			return
		}
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("ReadAll: %v", err)
			return
		}
	}
	// Start three requests one after another, waiting until each has reached
	// the handler, so that three connections are in use at once.
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
	}

	// Release the handlers one at a time and inspect the idle pool after
	// each request completes.
	resch <- "res1"
	<-donech
	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
	}
	addr := ts.Listener.Addr().String()
	cacheKey := "|http|" + addr
	if keys[0] != cacheKey {
		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
	}
	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
	}

	resch <- "res2"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
		t.Errorf("after second response, idle conns = %d; want %d", g, w)
	}

	// The third connection would exceed MaxIdleConnsPerHost, so the idle
	// count must stay at the limit.
	resch <- "res3"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
		t.Errorf("after third response, idle conns = %d; want %d", g, w)
	}
}
563
564 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
565 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
566 }
567 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
568 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
569 _, err := w.Write([]byte("foo"))
570 if err != nil {
571 t.Fatalf("Write: %v", err)
572 }
573 })).ts
574 c := ts.Client()
575 tr := c.Transport.(*Transport)
576 dialStarted := make(chan struct{})
577 stallDial := make(chan struct{})
578 tr.Dial = func(network, addr string) (net.Conn, error) {
579 dialStarted <- struct{}{}
580 <-stallDial
581 return net.Dial(network, addr)
582 }
583
584 tr.DisableKeepAlives = true
585 tr.MaxConnsPerHost = 1
586
587 preDial := make(chan struct{})
588 reqComplete := make(chan struct{})
589 doReq := func(reqId string) {
590 req, _ := NewRequest("GET", ts.URL, nil)
591 trace := &httptrace.ClientTrace{
592 GetConn: func(hostPort string) {
593 preDial <- struct{}{}
594 },
595 }
596 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
597 resp, err := tr.RoundTrip(req)
598 if err != nil {
599 t.Errorf("unexpected error for request %s: %v", reqId, err)
600 }
601 _, err = io.ReadAll(resp.Body)
602 if err != nil {
603 t.Errorf("unexpected error for request %s: %v", reqId, err)
604 }
605 reqComplete <- struct{}{}
606 }
607
608 go doReq("req1")
609 <-preDial
610 <-dialStarted
611
612
613 go doReq("req2")
614 <-preDial
615 select {
616 case <-dialStarted:
617 t.Error("req2 dial started while req1 dial in progress")
618 return
619 default:
620 }
621
622
623 stallDial <- struct{}{}
624 <-reqComplete
625
626
627 <-dialStarted
628 stallDial <- struct{}{}
629 <-reqComplete
630 }
631
632 func TestTransportMaxConnsPerHost(t *testing.T) {
633 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
634 }
635 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
636 CondSkipHTTP2(t)
637
638 h := HandlerFunc(func(w ResponseWriter, r *Request) {
639 _, err := w.Write([]byte("foo"))
640 if err != nil {
641 t.Fatalf("Write: %v", err)
642 }
643 })
644
645 ts := newClientServerTest(t, mode, h).ts
646 c := ts.Client()
647 tr := c.Transport.(*Transport)
648 tr.MaxConnsPerHost = 1
649
650 mu := sync.Mutex{}
651 var conns []net.Conn
652 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
653 tr.Dial = func(network, addr string) (net.Conn, error) {
654 atomic.AddInt32(&dialCnt, 1)
655 c, err := net.Dial(network, addr)
656 mu.Lock()
657 defer mu.Unlock()
658 conns = append(conns, c)
659 return c, err
660 }
661
662 doReq := func() {
663 trace := &httptrace.ClientTrace{
664 GotConn: func(connInfo httptrace.GotConnInfo) {
665 if !connInfo.Reused {
666 atomic.AddInt32(&gotConnCnt, 1)
667 }
668 },
669 TLSHandshakeStart: func() {
670 atomic.AddInt32(&tlsHandshakeCnt, 1)
671 },
672 }
673 req, _ := NewRequest("GET", ts.URL, nil)
674 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
675
676 resp, err := c.Do(req)
677 if err != nil {
678 t.Fatalf("request failed: %v", err)
679 }
680 defer resp.Body.Close()
681 _, err = io.ReadAll(resp.Body)
682 if err != nil {
683 t.Fatalf("read body failed: %v", err)
684 }
685 }
686
687 wg := sync.WaitGroup{}
688 for i := 0; i < 10; i++ {
689 wg.Add(1)
690 go func() {
691 defer wg.Done()
692 doReq()
693 }()
694 }
695 wg.Wait()
696
697 expected := int32(tr.MaxConnsPerHost)
698 if dialCnt != expected {
699 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
700 }
701 if gotConnCnt != expected {
702 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
703 }
704 if ts.TLS != nil && tlsHandshakeCnt != expected {
705 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
706 }
707
708 if t.Failed() {
709 t.FailNow()
710 }
711
712 mu.Lock()
713 for _, c := range conns {
714 c.Close()
715 }
716 conns = nil
717 mu.Unlock()
718 tr.CloseIdleConnections()
719
720 doReq()
721 expected++
722 if dialCnt != expected {
723 t.Errorf("round 2: too many dials: %d", dialCnt)
724 }
725 if gotConnCnt != expected {
726 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
727 }
728 if ts.TLS != nil && tlsHandshakeCnt != expected {
729 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
730 }
731 }
732
733 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
734 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
735 }
736 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
737 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
738 io.WriteString(w, r.RemoteAddr)
739 })).ts
740
741 c := ts.Client()
742 tr := c.Transport.(*Transport)
743
744 doReq := func(name string) {
745
746
747 res, err := c.Post(ts.URL, "", nil)
748 if err != nil {
749 t.Fatalf("%s: %v", name, err)
750 }
751 if res.StatusCode != 200 {
752 t.Fatalf("%s: %v", name, res.Status)
753 }
754 defer res.Body.Close()
755 slurp, err := io.ReadAll(res.Body)
756 if err != nil {
757 t.Fatalf("%s: %v", name, err)
758 }
759 t.Logf("%s: ok (%q)", name, slurp)
760 }
761
762 doReq("first")
763 keys1 := tr.IdleConnKeysForTesting()
764
765 ts.CloseClientConnections()
766
767 var keys2 []string
768 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
769 keys2 = tr.IdleConnKeysForTesting()
770 if len(keys2) != 0 {
771 if d > 0 {
772 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
773 }
774 return false
775 }
776 return true
777 })
778
779 doReq("second")
780 }
781
782
783
784 func TestTransportServerClosingUnexpectedly(t *testing.T) {
785 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
786 }
787 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
788 ts := newClientServerTest(t, mode, hostPortHandler).ts
789 c := ts.Client()
790
791 fetch := func(n, retries int) string {
792 condFatalf := func(format string, arg ...any) {
793 if retries <= 0 {
794 t.Fatalf(format, arg...)
795 }
796 t.Logf("retrying shortly after expected error: "+format, arg...)
797 time.Sleep(time.Second / time.Duration(retries))
798 }
799 for retries >= 0 {
800 retries--
801 res, err := c.Get(ts.URL)
802 if err != nil {
803 condFatalf("error in req #%d, GET: %v", n, err)
804 continue
805 }
806 body, err := io.ReadAll(res.Body)
807 if err != nil {
808 condFatalf("error in req #%d, ReadAll: %v", n, err)
809 continue
810 }
811 res.Body.Close()
812 return string(body)
813 }
814 panic("unreachable")
815 }
816
817 body1 := fetch(1, 0)
818 body2 := fetch(2, 0)
819
820
821
822
823
824
825
826
827 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
828
829 body3 := fetch(3, 5)
830
831 if body1 != body2 {
832 t.Errorf("expected body1 and body2 to be equal")
833 }
834 if body2 == body3 {
835 t.Errorf("expected body2 and body3 to be different")
836 }
837 }
838
839
840
841 func TestStressSurpriseServerCloses(t *testing.T) {
842 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
843 }
844 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
845 if testing.Short() {
846 t.Skip("skipping test in short mode")
847 }
848 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
849 w.Header().Set("Content-Length", "5")
850 w.Header().Set("Content-Type", "text/plain")
851 w.Write([]byte("Hello"))
852 w.(Flusher).Flush()
853 conn, buf, _ := w.(Hijacker).Hijack()
854 buf.Flush()
855 conn.Close()
856 })).ts
857 c := ts.Client()
858
859
860
861
862
863
864
865 const (
866 numClients = 20
867 reqsPerClient = 25
868 )
869 var wg sync.WaitGroup
870 wg.Add(numClients * reqsPerClient)
871 for i := 0; i < numClients; i++ {
872 go func() {
873 for i := 0; i < reqsPerClient; i++ {
874 res, err := c.Get(ts.URL)
875 if err == nil {
876
877
878
879
880
881
882 res.Body.Close()
883 }
884 wg.Done()
885 }
886 }()
887 }
888
889
890 wg.Wait()
891 }
892
893
894
895 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
896 func testTransportHeadResponses(t *testing.T, mode testMode) {
897 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
898 if r.Method != "HEAD" {
899 panic("expected HEAD; got " + r.Method)
900 }
901 w.Header().Set("Content-Length", "123")
902 w.WriteHeader(200)
903 })).ts
904 c := ts.Client()
905
906 for i := 0; i < 2; i++ {
907 res, err := c.Head(ts.URL)
908 if err != nil {
909 t.Errorf("error on loop %d: %v", i, err)
910 continue
911 }
912 if e, g := "123", res.Header.Get("Content-Length"); e != g {
913 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
914 }
915 if e, g := int64(123), res.ContentLength; e != g {
916 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
917 }
918 if all, err := io.ReadAll(res.Body); err != nil {
919 t.Errorf("loop %d: Body ReadAll: %v", i, err)
920 } else if len(all) != 0 {
921 t.Errorf("Bogus body %q", all)
922 }
923 }
924 }
925
926
927
// TestTransportHeadChunkedResponse checks that two HEAD requests answered
// with "Transfer-Encoding: chunked" (and no body) are served from the same
// client ip:port. The handler echoes r.RemoteAddr in the x-client-ipport
// header, so equal values mean the connection was reused.
//
// NOTE(review): SetReadLoopBeforeNextReadHook installs a package-level hook,
// which is presumably why this test is marked testNotParallel — confirm.
func TestTransportHeadChunkedResponse(t *testing.T) {
	run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		w.Header().Set("Transfer-Encoding", "chunked")
		w.Header().Set("x-client-ipport", r.RemoteAddr)
		w.WriteHeader(200)
	})).ts
	c := ts.Client()

	// didRead is signalled by the hook; judging by its name, the hook is
	// called by the Transport's read loop before it reads the next
	// response (TODO confirm). Waiting on it after each request keeps the
	// second HEAD from racing the first connection's return to the pool.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	// Wait for the hook before checking err, so the hook's unbuffered send
	// is always consumed.
	res1, err := c.Head(ts.URL)
	<-didRead

	if err != nil {
		t.Fatalf("request 1 error: %v", err)
	}

	res2, err := c.Head(ts.URL)
	<-didRead

	if err != nil {
		t.Fatalf("request 2 error: %v", err)
	}
	if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
	}
}
965
// roundTripTests is the table driven by testRoundTripGzip.
//   - accept is the Accept-Encoding header the caller sets ("" means none).
//   - expectAccept is the Accept-Encoding the handler must observe.
//   - compressed reports whether the caller receives a gzip stream that it
//     must decompress itself.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No header from the caller: the Transport is expected to request gzip
	// and hand back an already-decompressed body.
	{accept: "", expectAccept: "gzip", compressed: false},
	// A caller-chosen encoding is forwarded as-is.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit gzip request: the caller gets the raw gzip stream.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
978
979
980 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
981 func testRoundTripGzip(t *testing.T, mode testMode) {
982 const responseBody = "test response body"
983 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
984 accept := req.Header.Get("Accept-Encoding")
985 if expect := req.FormValue("expect_accept"); accept != expect {
986 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
987 req.FormValue("testnum"), accept, expect)
988 }
989 if accept == "gzip" {
990 rw.Header().Set("Content-Encoding", "gzip")
991 gz := gzip.NewWriter(rw)
992 gz.Write([]byte(responseBody))
993 gz.Close()
994 } else {
995 rw.Header().Set("Content-Encoding", accept)
996 rw.Write([]byte(responseBody))
997 }
998 })).ts
999 tr := ts.Client().Transport.(*Transport)
1000
1001 for i, test := range roundTripTests {
1002
1003 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1004 if test.accept != "" {
1005 req.Header.Set("Accept-Encoding", test.accept)
1006 }
1007 res, err := tr.RoundTrip(req)
1008 if err != nil {
1009 t.Errorf("%d. RoundTrip: %v", i, err)
1010 continue
1011 }
1012 var body []byte
1013 if test.compressed {
1014 var r *gzip.Reader
1015 r, err = gzip.NewReader(res.Body)
1016 if err != nil {
1017 t.Errorf("%d. gzip NewReader: %v", i, err)
1018 continue
1019 }
1020 body, err = io.ReadAll(r)
1021 res.Body.Close()
1022 } else {
1023 body, err = io.ReadAll(res.Body)
1024 }
1025 if err != nil {
1026 t.Errorf("%d. Error: %q", i, err)
1027 continue
1028 }
1029 if g, e := string(body), responseBody; g != e {
1030 t.Errorf("%d. body = %q; want %q", i, g, e)
1031 }
1032 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1033 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1034 }
1035 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1036 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1037 }
1038 }
1039
1040 }
1041
// TestTransportGzip checks the Transport's transparent gzip support with both
// chunked and Content-Length-delimited gzip responses:
//   - GETs carry "Accept-Encoding: gzip" and come back decompressed, with
//     Content-Encoding cleared.
//   - Reads fail after Close and after the body is exhausted.
//   - HEAD requests carry no Accept-Encoding at all.
func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
func testTransportGzip(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56020")
	}
	const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	const nRandBytes = 1024 * 1024
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		if req.Method == "HEAD" {
			if g := req.Header.Get("Accept-Encoding"); g != "" {
				t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
			}
			return
		}
		if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
			t.Errorf("Accept-Encoding = %q, want %q", g, e)
		}
		rw.Header().Set("Content-Encoding", "gzip")

		var w io.Writer = rw
		var buf bytes.Buffer
		if req.FormValue("chunked") == "0" {
			// Buffer the gzip output so a Content-Length can be sent.
			// Deferred calls run last-in-first-out: the header is set first,
			// then the buffered bytes are copied out.
			w = &buf
			defer io.Copy(rw, &buf)
			defer func() {
				rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
			}()
		}
		gz := gzip.NewWriter(w)
		gz.Write([]byte(testString))
		if req.FormValue("body") == "large" {
			// Append nRandBytes of crypto/rand output after testString.
			io.CopyN(gz, rand.Reader, nRandBytes)
		}
		gz.Close()
	})).ts
	c := ts.Client()

	for _, chunked := range []string{"1", "0"} {
		// Fetch the large body but read only its testString prefix, then
		// close early.
		res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
		if err != nil {
			t.Fatalf("large get: %v", err)
		}
		buf := make([]byte, len(testString))
		n, err := io.ReadFull(res.Body, buf)
		if err != nil {
			t.Fatalf("partial read of large response: size=%d, %v", n, err)
		}
		if e, g := testString, string(buf); e != g {
			t.Errorf("partial read got %q, expected %q", g, e)
		}
		res.Body.Close()
		// A Read after Close must fail.
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
		}

		// Fetch the small body in full.
		res, err = c.Get(ts.URL + "/?chunked=" + chunked)
		if err != nil {
			t.Fatal(err)
		}
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatal(err)
		}
		if g, e := string(body), testString; g != e {
			t.Fatalf("body = %q; want %q", g, e)
		}
		// Transparent decompression is expected to remove Content-Encoding.
		if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
			t.Fatalf("Content-Encoding = %q; want %q", g, e)
		}

		// Reads keep failing once the body is exhausted, and after Close.
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
		}
		res.Body.Close()
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected Read error after Close; got %d, %v", n, err)
		}
	}

	// A HEAD request must also succeed; the handler asserts that it carried
	// no Accept-Encoding.
	// NOTE(review): this response's Body is never closed.
	res, err := c.Head(ts.URL)
	if err != nil {
		t.Fatalf("Head: %v", err)
	}
	if res.StatusCode != 200 {
		t.Errorf("Head status=%d; want=200", res.StatusCode)
	}
}
1137
1138
1139
// TestTransportExpect100Continue sends PUTs with "Expect: 100-continue" to
// handlers that react differently. For each one it checks the status code
// and how many body bytes the client actually sent. The byte count is
// measured by how much of the bytes.Reader body was consumed.
func TestTransportExpect100Continue(t *testing.T) {
	run(t, testTransportExpect100Continue, []testMode{http1Mode})
}
func testTransportExpect100Continue(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		switch req.URL.Path {
		case "/100":
			// Reads the whole body before replying 200; the table expects
			// the full body to be sent.
			if _, err := io.Copy(io.Discard, req.Body); err != nil {
				t.Error("Failed to read Body", err)
			}
			rw.WriteHeader(StatusOK)
		case "/200":
			// Replies 200 without reading the body; the table expects no
			// body bytes to be sent.
			rw.WriteHeader(StatusOK)
		case "/500":
			rw.WriteHeader(StatusInternalServerError)
		case "/keepalive":
			// Writes a raw 500 response on the hijacked connection without
			// reading the body.
			// NOTE(review): log.Fatal exits the whole test binary; t.Error plus
			// return would be safer here.
			_, bufrw, err := rw.(Hijacker).Hijack()
			if err != nil {
				log.Fatal(err)
			}
			bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
			bufrw.WriteString("Content-Length: 0\r\n\r\n")
			bufrw.Flush()
		case "/timeout":
			// Never sends "100 Continue", but reads ContentLength bytes from
			// the hijacked connection. The table expects the full body
			// (presumably sent once the client's ExpectContinueTimeout
			// expires — confirm) and then a 200.
			conn, bufrw, err := rw.(Hijacker).Hijack()
			if err != nil {
				log.Fatal(err)
			}
			if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
				t.Error("Failed to read Body", err)
			}
			bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
			bufrw.Flush()
			conn.Close()
		}

	})).ts

	// sent is the number of body bytes the client is expected to transmit.
	tests := []struct {
		path   string
		body   []byte
		sent   int
		status int
	}{
		{path: "/100", body: []byte("hello"), sent: 5, status: 200},
		{path: "/200", body: []byte("hello"), sent: 0, status: 200},
		{path: "/500", body: []byte("hello"), sent: 0, status: 500},
		{path: "/keepalive", body: []byte("hello"), sent: 0, status: 500},
		{path: "/timeout", body: []byte("hello"), sent: 5, status: 200},
	}

	c := ts.Client()
	for i, v := range tests {
		// Each case gets a fresh Transport, so no connection carries over.
		// The deferred CloseIdleConnections calls all run when the test
		// returns.
		tr := &Transport{
			ExpectContinueTimeout: 2 * time.Second,
		}
		defer tr.CloseIdleConnections()
		c.Transport = tr
		body := bytes.NewReader(v.body)
		req, err := NewRequest("PUT", ts.URL+v.path, body)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Set("Expect", "100-continue")
		req.ContentLength = int64(len(v.body))

		resp, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		resp.Body.Close()

		// Bytes consumed from the reader equal bytes handed to the wire.
		sent := len(v.body) - body.Len()
		if v.status != resp.StatusCode {
			t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
		}
		if v.sent != sent {
			t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
		}
	}
}
1227
1228 func TestSOCKS5Proxy(t *testing.T) {
1229 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1230 }
1231 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1232 ch := make(chan string, 1)
1233 l := newLocalListener(t)
1234 defer l.Close()
1235 defer close(ch)
1236 proxy := func(t *testing.T) {
1237 s, err := l.Accept()
1238 if err != nil {
1239 t.Errorf("socks5 proxy Accept(): %v", err)
1240 return
1241 }
1242 defer s.Close()
1243 var buf [22]byte
1244 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1245 t.Errorf("socks5 proxy initial read: %v", err)
1246 return
1247 }
1248 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1249 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1250 return
1251 }
1252 if _, err := s.Write([]byte{5, 0}); err != nil {
1253 t.Errorf("socks5 proxy initial write: %v", err)
1254 return
1255 }
1256 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1257 t.Errorf("socks5 proxy second read: %v", err)
1258 return
1259 }
1260 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1261 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1262 return
1263 }
1264 var ipLen int
1265 switch buf[3] {
1266 case 1:
1267 ipLen = net.IPv4len
1268 case 4:
1269 ipLen = net.IPv6len
1270 default:
1271 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1272 return
1273 }
1274 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1275 t.Errorf("socks5 proxy address read: %v", err)
1276 return
1277 }
1278 ip := net.IP(buf[4 : ipLen+4])
1279 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1280 copy(buf[:3], []byte{5, 0, 0})
1281 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1282 t.Errorf("socks5 proxy connect write: %v", err)
1283 return
1284 }
1285 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1286
1287
1288 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1289 targetConn, err := net.Dial("tcp", targetHost)
1290 if err != nil {
1291 t.Errorf("net.Dial failed")
1292 return
1293 }
1294 go io.Copy(targetConn, s)
1295 io.Copy(s, targetConn)
1296 targetConn.Close()
1297 }
1298
1299 pu, err := url.Parse("socks5://" + l.Addr().String())
1300 if err != nil {
1301 t.Fatal(err)
1302 }
1303
1304 sentinelHeader := "X-Sentinel"
1305 sentinelValue := "12345"
1306 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1307 w.Header().Set(sentinelHeader, sentinelValue)
1308 })
1309 for _, useTLS := range []bool{false, true} {
1310 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1311 ts := newClientServerTest(t, mode, h).ts
1312 go proxy(t)
1313 c := ts.Client()
1314 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1315 r, err := c.Head(ts.URL)
1316 if err != nil {
1317 t.Fatal(err)
1318 }
1319 if r.Header.Get(sentinelHeader) != sentinelValue {
1320 t.Errorf("Failed to retrieve sentinel value")
1321 }
1322 got := <-ch
1323 ts.Close()
1324 tsu, err := url.Parse(ts.URL)
1325 if err != nil {
1326 t.Fatal(err)
1327 }
1328 want := "proxy for " + tsu.Host
1329 if got != want {
1330 t.Errorf("got %q, want %q", got, want)
1331 }
1332 })
1333 }
1334 }
1335
1336 func TestTransportProxy(t *testing.T) {
1337 defer afterTest(t)
1338 testCases := []struct{ siteMode, proxyMode testMode }{
1339 {http1Mode, http1Mode},
1340 {http1Mode, https1Mode},
1341 {https1Mode, http1Mode},
1342 {https1Mode, https1Mode},
1343 }
1344 for _, testCase := range testCases {
1345 siteMode := testCase.siteMode
1346 proxyMode := testCase.proxyMode
1347 t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
1348 siteCh := make(chan *Request, 1)
1349 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1350 siteCh <- r
1351 })
1352 proxyCh := make(chan *Request, 1)
1353 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1354 proxyCh <- r
1355
1356 if r.Method == "CONNECT" {
1357 hijacker, ok := w.(Hijacker)
1358 if !ok {
1359 t.Errorf("hijack not allowed")
1360 return
1361 }
1362 clientConn, _, err := hijacker.Hijack()
1363 if err != nil {
1364 t.Errorf("hijacking failed")
1365 return
1366 }
1367 res := &Response{
1368 StatusCode: StatusOK,
1369 Proto: "HTTP/1.1",
1370 ProtoMajor: 1,
1371 ProtoMinor: 1,
1372 Header: make(Header),
1373 }
1374
1375 targetConn, err := net.Dial("tcp", r.URL.Host)
1376 if err != nil {
1377 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1378 return
1379 }
1380
1381 if err := res.Write(clientConn); err != nil {
1382 t.Errorf("Writing 200 OK failed: %v", err)
1383 return
1384 }
1385
1386 go io.Copy(targetConn, clientConn)
1387 go func() {
1388 io.Copy(clientConn, targetConn)
1389 targetConn.Close()
1390 }()
1391 }
1392 })
1393 ts := newClientServerTest(t, siteMode, h1).ts
1394 proxy := newClientServerTest(t, proxyMode, h2).ts
1395
1396 pu, err := url.Parse(proxy.URL)
1397 if err != nil {
1398 t.Fatal(err)
1399 }
1400
1401
1402
1403
1404 c := proxy.Client()
1405 if siteMode == https1Mode {
1406 c = ts.Client()
1407 }
1408
1409 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1410 if _, err := c.Head(ts.URL); err != nil {
1411 t.Error(err)
1412 }
1413 got := <-proxyCh
1414 c.Transport.(*Transport).CloseIdleConnections()
1415 ts.Close()
1416 proxy.Close()
1417 if siteMode == https1Mode {
1418
1419 if got.Method != "CONNECT" {
1420 t.Errorf("Wrong method for secure proxying: %q", got.Method)
1421 }
1422 gotHost := got.URL.Host
1423 pu, err := url.Parse(ts.URL)
1424 if err != nil {
1425 t.Fatal("Invalid site URL")
1426 }
1427 if wantHost := pu.Host; gotHost != wantHost {
1428 t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
1429 }
1430
1431
1432 next := <-siteCh
1433 if next.Method != "HEAD" {
1434 t.Errorf("Wrong method at destination: %s", next.Method)
1435 }
1436 if nextURL := next.URL.String(); nextURL != "/" {
1437 t.Errorf("Wrong URL at destination: %s", nextURL)
1438 }
1439 } else {
1440 if got.Method != "HEAD" {
1441 t.Errorf("Wrong method for destination: %q", got.Method)
1442 }
1443 gotURL := got.URL.String()
1444 wantURL := ts.URL + "/"
1445 if gotURL != wantURL {
1446 t.Errorf("Got URL %q, want %q", gotURL, wantURL)
1447 }
1448 }
1449 })
1450 }
1451 }
1452
1453 func TestOnProxyConnectResponse(t *testing.T) {
1454
1455 var tcases = []struct {
1456 proxyStatusCode int
1457 err error
1458 }{
1459 {
1460 StatusOK,
1461 nil,
1462 },
1463 {
1464 StatusForbidden,
1465 errors.New("403"),
1466 },
1467 }
1468 for _, tcase := range tcases {
1469 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1470
1471 })
1472
1473 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1474
1475 if r.Method == "CONNECT" {
1476 if tcase.proxyStatusCode != StatusOK {
1477 w.WriteHeader(tcase.proxyStatusCode)
1478 return
1479 }
1480 hijacker, ok := w.(Hijacker)
1481 if !ok {
1482 t.Errorf("hijack not allowed")
1483 return
1484 }
1485 clientConn, _, err := hijacker.Hijack()
1486 if err != nil {
1487 t.Errorf("hijacking failed")
1488 return
1489 }
1490 res := &Response{
1491 StatusCode: StatusOK,
1492 Proto: "HTTP/1.1",
1493 ProtoMajor: 1,
1494 ProtoMinor: 1,
1495 Header: make(Header),
1496 }
1497
1498 targetConn, err := net.Dial("tcp", r.URL.Host)
1499 if err != nil {
1500 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1501 return
1502 }
1503
1504 if err := res.Write(clientConn); err != nil {
1505 t.Errorf("Writing 200 OK failed: %v", err)
1506 return
1507 }
1508
1509 go io.Copy(targetConn, clientConn)
1510 go func() {
1511 io.Copy(clientConn, targetConn)
1512 targetConn.Close()
1513 }()
1514 }
1515 })
1516 ts := newClientServerTest(t, https1Mode, h1).ts
1517 proxy := newClientServerTest(t, https1Mode, h2).ts
1518
1519 pu, err := url.Parse(proxy.URL)
1520 if err != nil {
1521 t.Fatal(err)
1522 }
1523
1524 c := proxy.Client()
1525
1526 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1527 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1528 if proxyURL.String() != pu.String() {
1529 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1530 }
1531
1532 if "https://"+connectReq.URL.String() != ts.URL {
1533 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1534 }
1535 return tcase.err
1536 }
1537 if _, err := c.Head(ts.URL); err != nil {
1538 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1539 t.Errorf("got %v, want %v", err, tcase.err)
1540 }
1541 }
1542 }
1543 }
1544
1545
1546
1547 func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
1548 setParallel(t)
1549 defer afterTest(t)
1550
1551 ctx, cancel := context.WithCancel(context.Background())
1552 defer cancel()
1553
1554 ln := newLocalListener(t)
1555 defer ln.Close()
1556 listenerDone := make(chan struct{})
1557 go func() {
1558 defer close(listenerDone)
1559 c, err := ln.Accept()
1560 if err != nil {
1561 t.Errorf("Accept: %v", err)
1562 return
1563 }
1564 defer c.Close()
1565
1566 br := bufio.NewReader(c)
1567 cr, err := ReadRequest(br)
1568 if err != nil {
1569 t.Errorf("proxy server failed to read CONNECT request")
1570 return
1571 }
1572 if cr.Method != "CONNECT" {
1573 t.Errorf("unexpected method %q", cr.Method)
1574 return
1575 }
1576
1577
1578
1579
1580 cancel()
1581 var buf [1]byte
1582 _, err = br.Read(buf[:])
1583 if err != io.EOF {
1584 t.Errorf("proxy server Read err = %v; want EOF", err)
1585 }
1586 return
1587 }()
1588
1589 c := &Client{
1590 Transport: &Transport{
1591 Proxy: func(*Request) (*url.URL, error) {
1592 return url.Parse("http://" + ln.Addr().String())
1593 },
1594 },
1595 }
1596 req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
1597 if err != nil {
1598 t.Fatal(err)
1599 }
1600 _, err = c.Do(req)
1601 if err == nil {
1602 t.Errorf("unexpected Get success")
1603 }
1604
1605
1606
1607
1608 <-listenerDone
1609 }
1610
1611
1612 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1613 defer afterTest(t)
1614
1615 var errDial = errors.New("some dial error")
1616
1617 tr := &Transport{
1618 Proxy: func(*Request) (*url.URL, error) {
1619 return url.Parse("http://proxy.fake.tld/")
1620 },
1621 Dial: func(string, string) (net.Conn, error) {
1622 return nil, errDial
1623 },
1624 }
1625 defer tr.CloseIdleConnections()
1626
1627 c := &Client{Transport: tr}
1628 req, _ := NewRequest("GET", "http://fake.tld", nil)
1629 res, err := c.Do(req)
1630 if err == nil {
1631 res.Body.Close()
1632 t.Fatal("wanted a non-nil error")
1633 }
1634
1635 uerr, ok := err.(*url.Error)
1636 if !ok {
1637 t.Fatalf("got %T, want *url.Error", err)
1638 }
1639 oe, ok := uerr.Err.(*net.OpError)
1640 if !ok {
1641 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1642 }
1643 want := &net.OpError{
1644 Op: "proxyconnect",
1645 Net: "tcp",
1646 Err: errDial,
1647 }
1648 if !reflect.DeepEqual(oe, want) {
1649 t.Errorf("Got error %#v; want %#v", oe, want)
1650 }
1651 }
1652
1653
1654
1655
1656
1657 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1658 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1659 }
1660 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1661 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1662 defer proxy.Close()
1663 c := proxy.Client()
1664
1665 tr := c.Transport.(*Transport)
1666 tr.Proxy = func(*Request) (*url.URL, error) {
1667 u, _ := url.Parse(proxy.URL)
1668 u.User = url.UserPassword("aladdin", "opensesame")
1669 return u, nil
1670 }
1671 h := tr.ProxyConnectHeader
1672 if h == nil {
1673 h = make(Header)
1674 }
1675 tr.ProxyConnectHeader = h.Clone()
1676
1677 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1678 if err != nil {
1679 t.Fatal(err)
1680 }
1681 _, err = c.Do(req)
1682 if err == nil {
1683 t.Errorf("unexpected Get success")
1684 }
1685
1686 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1687 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1688 }
1689 }
1690
1691
1692
1693
1694
1695 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1696 func testTransportGzipRecursive(t *testing.T, mode testMode) {
1697 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1698 w.Header().Set("Content-Encoding", "gzip")
1699 w.Write(rgz)
1700 })).ts
1701
1702 c := ts.Client()
1703 res, err := c.Get(ts.URL)
1704 if err != nil {
1705 t.Fatal(err)
1706 }
1707 body, err := io.ReadAll(res.Body)
1708 if err != nil {
1709 t.Fatal(err)
1710 }
1711 if !bytes.Equal(body, rgz) {
1712 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1713 body, rgz)
1714 }
1715 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1716 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1717 }
1718 }
1719
1720
1721
1722 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1723 func testTransportGzipShort(t *testing.T, mode testMode) {
1724 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1725 w.Header().Set("Content-Encoding", "gzip")
1726 w.Write([]byte{0x1f, 0x8b})
1727 })).ts
1728
1729 c := ts.Client()
1730 res, err := c.Get(ts.URL)
1731 if err != nil {
1732 t.Fatal(err)
1733 }
1734 defer res.Body.Close()
1735 _, err = io.ReadAll(res.Body)
1736 if err == nil {
1737 t.Fatal("Expect an error from reading a body.")
1738 }
1739 if err != io.ErrUnexpectedEOF {
1740 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1741 }
1742 }
1743
1744
// waitNumGoroutine polls runtime.NumGoroutine until the count is at most
// nmax, sleeping 50ms and forcing a GC before each re-check, for at most ten
// re-checks. It returns the last count observed, which may still exceed nmax.
func waitNumGoroutine(nmax int) int {
	count := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if count <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		count = runtime.NumGoroutine()
	}
	return count
}
1754
1755
1756 func TestTransportPersistConnLeak(t *testing.T) {
1757 run(t, testTransportPersistConnLeak, testNotParallel)
1758 }
1759 func testTransportPersistConnLeak(t *testing.T, mode testMode) {
1760 if mode == http2Mode {
1761 t.Skip("flaky in HTTP/2")
1762 }
1763
1764
1765 const numReq = 25
1766 gotReqCh := make(chan bool, numReq)
1767 unblockCh := make(chan bool, numReq)
1768 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1769 gotReqCh <- true
1770 <-unblockCh
1771 w.Header().Set("Content-Length", "0")
1772 w.WriteHeader(204)
1773 })).ts
1774 c := ts.Client()
1775 tr := c.Transport.(*Transport)
1776
1777 n0 := runtime.NumGoroutine()
1778
1779 didReqCh := make(chan bool, numReq)
1780 failed := make(chan bool, numReq)
1781 for i := 0; i < numReq; i++ {
1782 go func() {
1783 res, err := c.Get(ts.URL)
1784 didReqCh <- true
1785 if err != nil {
1786 t.Logf("client fetch error: %v", err)
1787 failed <- true
1788 return
1789 }
1790 res.Body.Close()
1791 }()
1792 }
1793
1794
1795 for i := 0; i < numReq; i++ {
1796 select {
1797 case <-gotReqCh:
1798
1799 case <-failed:
1800
1801
1802 }
1803 }
1804
1805 nhigh := runtime.NumGoroutine()
1806
1807
1808 close(unblockCh)
1809
1810
1811 for i := 0; i < numReq; i++ {
1812 <-didReqCh
1813 }
1814
1815 tr.CloseIdleConnections()
1816 nfinal := waitNumGoroutine(n0 + 5)
1817
1818 growth := nfinal - n0
1819
1820
1821
1822 if int(growth) > 5 {
1823 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
1824 t.Error("too many new goroutines")
1825 }
1826 }
1827
1828
1829
1830 func TestTransportPersistConnLeakShortBody(t *testing.T) {
1831 run(t, testTransportPersistConnLeakShortBody, testNotParallel)
1832 }
1833 func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
1834 if mode == http2Mode {
1835 t.Skip("flaky in HTTP/2")
1836 }
1837
1838
1839 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1840 })).ts
1841 c := ts.Client()
1842 tr := c.Transport.(*Transport)
1843
1844 n0 := runtime.NumGoroutine()
1845 body := []byte("Hello")
1846 for i := 0; i < 20; i++ {
1847 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
1848 if err != nil {
1849 t.Fatal(err)
1850 }
1851 req.ContentLength = int64(len(body) - 2)
1852 _, err = c.Do(req)
1853 if err == nil {
1854 t.Fatal("Expect an error from writing too long of a body.")
1855 }
1856 }
1857 nhigh := runtime.NumGoroutine()
1858 tr.CloseIdleConnections()
1859 nfinal := waitNumGoroutine(n0 + 5)
1860
1861 growth := nfinal - n0
1862
1863
1864
1865 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
1866 if int(growth) > 5 {
1867 t.Error("too many new goroutines")
1868 }
1869 }
1870
1871
// countedConn wraps a net.Conn so that countingDialer can attach a finalizer
// to every connection it hands out.
type countedConn struct {
	net.Conn
}

// countingDialer dials through its dialer field and counts the connections
// it has produced (total) and how many of them have not yet been finalized
// by the garbage collector (live). All counters are guarded by mu.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials network/address and, on success, returns the connection
// wrapped in a countedConn whose finalizer decrements d.live. Errors from the
// underlying dialer are returned unchanged and leave the counters untouched.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	raw, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: raw}

	d.mu.Lock()
	d.total++
	d.live++
	runtime.SetFinalizer(wrapped, d.decrement)
	d.mu.Unlock()
	return wrapped, nil
}

// decrement is the finalizer for countedConn: it records that one dialed
// connection has been garbage collected.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns a consistent snapshot of the total and live counters.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
1912
1913 func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
1914 run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
1915 }
1916 func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
1917 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1918
1919 conn, _, err := w.(Hijacker).Hijack()
1920 if err != nil {
1921 t.Errorf("Hijack failed unexpectedly: %v", err)
1922 return
1923 }
1924 conn.Close()
1925 })).ts
1926
1927 var d countingDialer
1928 c := ts.Client()
1929 c.Transport.(*Transport).DialContext = d.DialContext
1930
1931 body := []byte("Hello")
1932 for i := 0; ; i++ {
1933 total, live := d.Read()
1934 if live < total {
1935 break
1936 }
1937 if i >= 1<<12 {
1938 t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
1939 }
1940
1941 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
1942 if err != nil {
1943 t.Fatal(err)
1944 }
1945 _, err = c.Do(req)
1946 if err == nil {
1947 t.Fatal("expected broken connection")
1948 }
1949
1950 runtime.GC()
1951 }
1952 }
1953
// countedContext wraps a context.Context so that contextCounter can attach a
// finalizer to every context returned by Track.
type countedContext struct {
	context.Context
}

// contextCounter counts contexts returned by Track that have not yet been
// finalized by the garbage collector. live is guarded by mu.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a countedContext, counts it as live, and registers a
// finalizer that uncounts it once the wrapper is garbage collected.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	wrapped := &countedContext{Context: ctx}
	cc.mu.Lock()
	cc.live++
	runtime.SetFinalizer(wrapped, cc.decrement)
	cc.mu.Unlock()
	return wrapped
}

// decrement is the finalizer for countedContext: it records that one tracked
// context has been garbage collected.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read returns the number of tracked contexts not yet finalized.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
1984
1985 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
1986 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
1987 }
1988 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
1989 if mode == http2Mode {
1990 t.Skip("https://go.dev/issue/56021")
1991 }
1992
1993 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1994 runtime.Gosched()
1995 w.WriteHeader(StatusOK)
1996 })).ts
1997
1998 c := ts.Client()
1999 c.Transport.(*Transport).MaxConnsPerHost = 1
2000
2001 ctx := context.Background()
2002 body := []byte("Hello")
2003 doPosts := func(cc *contextCounter) {
2004 var wg sync.WaitGroup
2005 for n := 64; n > 0; n-- {
2006 wg.Add(1)
2007 go func() {
2008 defer wg.Done()
2009
2010 ctx := cc.Track(ctx)
2011 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2012 if err != nil {
2013 t.Error(err)
2014 }
2015
2016 _, err = c.Do(req.WithContext(ctx))
2017 if err != nil {
2018 t.Errorf("Do failed with error: %v", err)
2019 }
2020 }()
2021 }
2022 wg.Wait()
2023 }
2024
2025 var initialCC contextCounter
2026 doPosts(&initialCC)
2027
2028
2029
2030
2031 var flushCC contextCounter
2032 for i := 0; ; i++ {
2033 live := initialCC.Read()
2034 if live == 0 {
2035 break
2036 }
2037 if i >= 100 {
2038 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2039 }
2040 doPosts(&flushCC)
2041 runtime.GC()
2042 }
2043 }
2044
2045
2046 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2047 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2048 var tr *Transport
2049
2050 unblockCh := make(chan bool, 1)
2051 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2052 <-unblockCh
2053 tr.CloseIdleConnections()
2054 })).ts
2055 c := ts.Client()
2056 tr = c.Transport.(*Transport)
2057
2058 didreq := make(chan bool)
2059 go func() {
2060 res, err := c.Get(ts.URL)
2061 if err != nil {
2062 t.Error(err)
2063 } else {
2064 res.Body.Close()
2065 }
2066 didreq <- true
2067 }()
2068 unblockCh <- true
2069 <-didreq
2070 }
2071
2072
2073
2074
2075
2076 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2077 func testIssue3644(t *testing.T, mode testMode) {
2078 const numFoos = 5000
2079 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2080 w.Header().Set("Connection", "close")
2081 for i := 0; i < numFoos; i++ {
2082 w.Write([]byte("foo "))
2083 }
2084 })).ts
2085 c := ts.Client()
2086 res, err := c.Get(ts.URL)
2087 if err != nil {
2088 t.Fatal(err)
2089 }
2090 defer res.Body.Close()
2091 bs, err := io.ReadAll(res.Body)
2092 if err != nil {
2093 t.Fatal(err)
2094 }
2095 if len(bs) != numFoos*len("foo ") {
2096 t.Errorf("unexpected response length")
2097 }
2098 }
2099
2100
2101
2102 func TestIssue3595(t *testing.T) { run(t, testIssue3595) }
2103 func testIssue3595(t *testing.T, mode testMode) {
2104 const deniedMsg = "sorry, denied."
2105 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2106 Error(w, deniedMsg, StatusUnauthorized)
2107 })).ts
2108 c := ts.Client()
2109 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2110 if err != nil {
2111 t.Errorf("Post: %v", err)
2112 return
2113 }
2114 got, err := io.ReadAll(res.Body)
2115 if err != nil {
2116 t.Fatalf("Body ReadAll: %v", err)
2117 }
2118 if !strings.Contains(string(got), deniedMsg) {
2119 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2120 }
2121 }
2122
2123
2124
2125 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2126 func testChunkedNoContent(t *testing.T, mode testMode) {
2127 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2128 w.WriteHeader(StatusNoContent)
2129 })).ts
2130
2131 c := ts.Client()
2132 for _, closeBody := range []bool{true, false} {
2133 const n = 4
2134 for i := 1; i <= n; i++ {
2135 res, err := c.Get(ts.URL)
2136 if err != nil {
2137 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2138 } else {
2139 if closeBody {
2140 res.Body.Close()
2141 }
2142 }
2143 }
2144 }
2145 }
2146
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}

// testTransportConcurrency raises GOMAXPROCS to maxProcs and sends numReqs
// echo requests from 2*maxProcs worker goroutines that share one client. It
// checks that each response body equals the value sent in the request's
// "echo" query parameter. It runs non-parallel because GOMAXPROCS and the
// pending-dial hooks are process-wide.
func testTransportConcurrency(t *testing.T, mode testMode) {
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Set GOMAXPROCS now and restore the previous value on return.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// One wg count per request; each worker calls wg.Done once per request.
	var wg sync.WaitGroup
	wg.Add(numReqs)

	// Also count dials tracked by the pending-dial hooks, so wg.Wait below
	// waits for those dials as well as for the requests.
	// NOTE(review): presumably the Transport calls the first hook when a
	// dial starts and the second when it finishes. Confirm against
	// SetPendingDialHooks.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	// Closing reqs on return stops the worker goroutines.
	defer close(reqs)

	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// Connection resets are tolerated (logged, not
						// failed) on NetBSD; see the issue linked below.
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2209
2210 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2211 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2212 mux := NewServeMux()
2213 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2214 io.Copy(w, neverEnding('a'))
2215 })
2216 ts := newClientServerTest(t, mode, mux).ts
2217
2218 connc := make(chan net.Conn, 1)
2219 c := ts.Client()
2220 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2221 conn, err := net.Dial(n, addr)
2222 if err != nil {
2223 return nil, err
2224 }
2225 select {
2226 case connc <- conn:
2227 default:
2228 }
2229 return conn, nil
2230 }
2231
2232 res, err := c.Get(ts.URL + "/get")
2233 if err != nil {
2234 t.Fatalf("Error issuing GET: %v", err)
2235 }
2236 defer res.Body.Close()
2237
2238 conn := <-connc
2239 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2240 _, err = io.Copy(io.Discard, res.Body)
2241 if err == nil {
2242 t.Errorf("Unexpected successful copy")
2243 }
2244 }
2245
2246 func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
2247 run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
2248 }
2249 func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
2250 const debug = false
2251 mux := NewServeMux()
2252 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2253 io.Copy(w, neverEnding('a'))
2254 })
2255 mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
2256 defer r.Body.Close()
2257 io.Copy(io.Discard, r.Body)
2258 })
2259 ts := newClientServerTest(t, mode, mux).ts
2260 timeout := 100 * time.Millisecond
2261
2262 c := ts.Client()
2263 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2264 conn, err := net.Dial(n, addr)
2265 if err != nil {
2266 return nil, err
2267 }
2268 conn.SetDeadline(time.Now().Add(timeout))
2269 if debug {
2270 conn = NewLoggingConn("client", conn)
2271 }
2272 return conn, nil
2273 }
2274
2275 getFailed := false
2276 nRuns := 5
2277 if testing.Short() {
2278 nRuns = 1
2279 }
2280 for i := 0; i < nRuns; i++ {
2281 if debug {
2282 println("run", i+1, "of", nRuns)
2283 }
2284 sres, err := c.Get(ts.URL + "/get")
2285 if err != nil {
2286 if !getFailed {
2287
2288 getFailed = true
2289 t.Logf("increasing timeout")
2290 i--
2291 timeout *= 10
2292 continue
2293 }
2294 t.Errorf("Error issuing GET: %v", err)
2295 break
2296 }
2297 req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
2298 _, err = c.Do(req)
2299 if err == nil {
2300 sres.Body.Close()
2301 t.Errorf("Unexpected successful PUT")
2302 break
2303 }
2304 sres.Body.Close()
2305 }
2306 if debug {
2307 println("tests complete; waiting for handlers to finish")
2308 }
2309 ts.Close()
2310 }
2311
2312 func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2313 func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2314 if testing.Short() {
2315 t.Skip("skipping timeout test in -short mode")
2316 }
2317
2318 timeout := 2 * time.Millisecond
2319 retry := true
2320 for retry && !t.Failed() {
2321 var srvWG sync.WaitGroup
2322 inHandler := make(chan bool, 1)
2323 mux := NewServeMux()
2324 mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2325 inHandler <- true
2326 srvWG.Done()
2327 })
2328 mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2329 inHandler <- true
2330 <-r.Context().Done()
2331 srvWG.Done()
2332 })
2333 ts := newClientServerTest(t, mode, mux).ts
2334
2335 c := ts.Client()
2336 c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2337
2338 retry = false
2339 srvWG.Add(3)
2340 tests := []struct {
2341 path string
2342 wantTimeout bool
2343 }{
2344 {path: "/fast"},
2345 {path: "/slow", wantTimeout: true},
2346 {path: "/fast"},
2347 }
2348 for i, tt := range tests {
2349 req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2350 req = req.WithT(t)
2351 res, err := c.Do(req)
2352 <-inHandler
2353 if err != nil {
2354 uerr, ok := err.(*url.Error)
2355 if !ok {
2356 t.Errorf("error is not a url.Error; got: %#v", err)
2357 continue
2358 }
2359 nerr, ok := uerr.Err.(net.Error)
2360 if !ok {
2361 t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2362 continue
2363 }
2364 if !nerr.Timeout() {
2365 t.Errorf("want timeout error; got: %q", nerr)
2366 continue
2367 }
2368 if !tt.wantTimeout {
2369 if !retry {
2370
2371 t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2372 timeout *= 2
2373 retry = true
2374 }
2375 }
2376 if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2377 t.Errorf("%d. unexpected error: %v", i, err)
2378 }
2379 continue
2380 }
2381 if tt.wantTimeout {
2382 t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2383 continue
2384 }
2385 if res.StatusCode != 200 {
2386 t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2387 }
2388 }
2389
2390 srvWG.Wait()
2391 ts.Close()
2392 }
2393 }
2394
2395 func TestTransportCancelRequest(t *testing.T) {
2396 run(t, testTransportCancelRequest, []testMode{http1Mode})
2397 }
2398 func testTransportCancelRequest(t *testing.T, mode testMode) {
2399 if testing.Short() {
2400 t.Skip("skipping test in -short mode")
2401 }
2402
2403 const msg = "Hello"
2404 unblockc := make(chan bool)
2405 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2406 io.WriteString(w, msg)
2407 w.(Flusher).Flush()
2408 <-unblockc
2409 })).ts
2410 defer close(unblockc)
2411
2412 c := ts.Client()
2413 tr := c.Transport.(*Transport)
2414
2415 req, _ := NewRequest("GET", ts.URL, nil)
2416 res, err := c.Do(req)
2417 if err != nil {
2418 t.Fatal(err)
2419 }
2420 body := make([]byte, len(msg))
2421 n, _ := io.ReadFull(res.Body, body)
2422 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2423 t.Errorf("Body = %q; want %q", body[:n], msg)
2424 }
2425 tr.CancelRequest(req)
2426
2427 tail, err := io.ReadAll(res.Body)
2428 res.Body.Close()
2429 if err != ExportErrRequestCanceled {
2430 t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
2431 } else if len(tail) > 0 {
2432 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2433 }
2434
2435
2436
2437 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2438 n := tr.NumPendingRequestsForTesting()
2439 if n > 0 {
2440 if d > 0 {
2441 t.Logf("pending requests = %d after %v (want 0)", n, d)
2442 }
2443 }
2444 return true
2445 })
2446 }
2447
2448 func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
2449 if testing.Short() {
2450 t.Skip("skipping test in -short mode")
2451 }
2452 unblockc := make(chan bool)
2453 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2454 <-unblockc
2455 })).ts
2456 defer close(unblockc)
2457
2458 c := ts.Client()
2459 tr := c.Transport.(*Transport)
2460
2461 donec := make(chan bool)
2462 req, _ := NewRequest("GET", ts.URL, body)
2463 go func() {
2464 defer close(donec)
2465 c.Do(req)
2466 }()
2467
2468 unblockc <- true
2469 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2470 tr.CancelRequest(req)
2471 select {
2472 case <-donec:
2473 return true
2474 default:
2475 if d > 0 {
2476 t.Logf("Do of canceled request has not returned after %v", d)
2477 }
2478 return false
2479 }
2480 })
2481 }
2482
2483 func TestTransportCancelRequestInDo(t *testing.T) {
2484 run(t, func(t *testing.T, mode testMode) {
2485 testTransportCancelRequestInDo(t, mode, nil)
2486 }, []testMode{http1Mode})
2487 }
2488
2489 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2490 run(t, func(t *testing.T, mode testMode) {
2491 testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
2492 }, []testMode{http1Mode})
2493 }
2494
// TestTransportCancelRequestInDial cancels a request while the Transport is
// blocked inside its Dial function. It then compares the exact sequence of
// recorded events against want:
//  1. the dial starts;
//  2. the test cancels;
//  3. Client.Do returns a "request canceled while waiting for connection"
//     error.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// Every step, from any goroutine, is appended to logbuf through eventLog.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// Closed when the test returns, which lets a Dial that is still blocked
	// below finish.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial is a handshake. The test's send on it completes only once Dial
	// is running, so the cancellation below happens while the dial is in
	// progress. A false value makes Dial give up immediately.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			// Stay blocked until the test returns; the request must be
			// canceled without this dial ever succeeding.
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	inDial <- true

	eventLog.Printf("canceling")
	tr.CancelRequest(req)
	// Cancel the same request a second time; the repeated call must be
	// harmless.
	tr.CancelRequest(req)

	if d, ok := t.Deadline(); ok {
		// If the test has a deadline, panic at 95% of the remaining time and
		// include the recorded events, so that a hang can be diagnosed.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2552
2553 func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
2554 func testCancelRequestWithChannel(t *testing.T, mode testMode) {
2555 if testing.Short() {
2556 t.Skip("skipping test in -short mode")
2557 }
2558
2559 const msg = "Hello"
2560 unblockc := make(chan struct{})
2561 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2562 io.WriteString(w, msg)
2563 w.(Flusher).Flush()
2564 <-unblockc
2565 })).ts
2566 defer close(unblockc)
2567
2568 c := ts.Client()
2569 tr := c.Transport.(*Transport)
2570
2571 req, _ := NewRequest("GET", ts.URL, nil)
2572 cancel := make(chan struct{})
2573 req.Cancel = cancel
2574
2575 res, err := c.Do(req)
2576 if err != nil {
2577 t.Fatal(err)
2578 }
2579 body := make([]byte, len(msg))
2580 n, _ := io.ReadFull(res.Body, body)
2581 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2582 t.Errorf("Body = %q; want %q", body[:n], msg)
2583 }
2584 close(cancel)
2585
2586 tail, err := io.ReadAll(res.Body)
2587 res.Body.Close()
2588 if err != ExportErrRequestCanceled {
2589 t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
2590 } else if len(tail) > 0 {
2591 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2592 }
2593
2594
2595
2596 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2597 n := tr.NumPendingRequestsForTesting()
2598 if n > 0 {
2599 if d > 0 {
2600 t.Logf("pending requests = %d after %v (want 0)", n, d)
2601 }
2602 }
2603 return true
2604 })
2605 }
2606
2607 func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
2608 run(t, func(t *testing.T, mode testMode) {
2609 testCancelRequestWithChannelBeforeDo(t, mode, false)
2610 })
2611 }
2612 func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
2613 run(t, func(t *testing.T, mode testMode) {
2614 testCancelRequestWithChannelBeforeDo(t, mode, true)
2615 })
2616 }
2617 func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
2618 unblockc := make(chan bool)
2619 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2620 <-unblockc
2621 })).ts
2622 defer close(unblockc)
2623
2624 c := ts.Client()
2625
2626 req, _ := NewRequest("GET", ts.URL, nil)
2627 if withCtx {
2628 ctx, cancel := context.WithCancel(context.Background())
2629 cancel()
2630 req = req.WithContext(ctx)
2631 } else {
2632 ch := make(chan struct{})
2633 req.Cancel = ch
2634 close(ch)
2635 }
2636
2637 _, err := c.Do(req)
2638 if ue, ok := err.(*url.Error); ok {
2639 err = ue.Err
2640 }
2641 if withCtx {
2642 if err != context.Canceled {
2643 t.Errorf("Do error = %v; want %v", err, context.Canceled)
2644 }
2645 } else {
2646 if err == nil || !strings.Contains(err.Error(), "canceled") {
2647 t.Errorf("Do error = %v; want cancellation", err)
2648 }
2649 }
2650 }
2651
2652
2653 func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
2654 defer afterTest(t)
2655
2656 serverConnCh := make(chan net.Conn, 1)
2657 tr := &Transport{
2658 Dial: func(network, addr string) (net.Conn, error) {
2659 cc, sc := net.Pipe()
2660 serverConnCh <- sc
2661 return cc, nil
2662 },
2663 }
2664 defer tr.CloseIdleConnections()
2665 errc := make(chan error, 1)
2666 req, _ := NewRequest("GET", "http://example.com/", nil)
2667 go func() {
2668 _, err := tr.RoundTrip(req)
2669 errc <- err
2670 }()
2671
2672 sc := <-serverConnCh
2673 verb := make([]byte, 3)
2674 if _, err := io.ReadFull(sc, verb); err != nil {
2675 t.Errorf("Error reading HTTP verb from server: %v", err)
2676 }
2677 if string(verb) != "GET" {
2678 t.Errorf("server received %q; want GET", verb)
2679 }
2680 defer sc.Close()
2681
2682 tr.CancelRequest(req)
2683
2684 err := <-errc
2685 if err == nil {
2686 t.Fatalf("unexpected success from RoundTrip")
2687 }
2688 if err != ExportErrRequestCanceled {
2689 t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
2690 }
2691 }
2692
2693
2694
2695
2696 func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
2697 func testTransportCloseResponseBody(t *testing.T, mode testMode) {
2698 writeErr := make(chan error, 1)
2699 msg := []byte("young\n")
2700 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2701 for {
2702 _, err := w.Write(msg)
2703 if err != nil {
2704 writeErr <- err
2705 return
2706 }
2707 w.(Flusher).Flush()
2708 }
2709 })).ts
2710
2711 c := ts.Client()
2712 tr := c.Transport.(*Transport)
2713
2714 req, _ := NewRequest("GET", ts.URL, nil)
2715 defer tr.CancelRequest(req)
2716
2717 res, err := c.Do(req)
2718 if err != nil {
2719 t.Fatal(err)
2720 }
2721
2722 const repeats = 3
2723 buf := make([]byte, len(msg)*repeats)
2724 want := bytes.Repeat(msg, repeats)
2725
2726 _, err = io.ReadFull(res.Body, buf)
2727 if err != nil {
2728 t.Fatal(err)
2729 }
2730 if !bytes.Equal(buf, want) {
2731 t.Fatalf("read %q; want %q", buf, want)
2732 }
2733
2734 if err := res.Body.Close(); err != nil {
2735 t.Errorf("Close = %v", err)
2736 }
2737
2738 if err := <-writeErr; err == nil {
2739 t.Errorf("expected non-nil write error")
2740 }
2741 }
2742
// fooProto is a RoundTripper for the fake "foo" scheme. It never uses the
// network: every request gets a 200 OK whose body echoes the request URL.
type fooProto struct{}

func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     Header{},
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
2754
2755 func TestTransportAltProto(t *testing.T) {
2756 defer afterTest(t)
2757 tr := &Transport{}
2758 c := &Client{Transport: tr}
2759 tr.RegisterProtocol("foo", fooProto{})
2760 res, err := c.Get("foo://bar.com/path")
2761 if err != nil {
2762 t.Fatal(err)
2763 }
2764 bodyb, err := io.ReadAll(res.Body)
2765 if err != nil {
2766 t.Fatal(err)
2767 }
2768 body := string(bodyb)
2769 if e := "You wanted foo://bar.com/path"; body != e {
2770 t.Errorf("got response %q, want %q", body, e)
2771 }
2772 }
2773
2774 func TestTransportNoHost(t *testing.T) {
2775 defer afterTest(t)
2776 tr := &Transport{}
2777 _, err := tr.RoundTrip(&Request{
2778 Header: make(Header),
2779 URL: &url.URL{
2780 Scheme: "http",
2781 },
2782 })
2783 want := "http: no Host in request URL"
2784 if got := fmt.Sprint(err); got != want {
2785 t.Errorf("error = %v; want %q", err, want)
2786 }
2787 }
2788
2789
// TestTransportEmptyMethod checks that an outgoing request with an empty
// Method is written on the wire as a GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	// net/http documents that, for client requests, "" means GET.
	req.Method = ""
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.Contains(string(dump), "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
2801
// TestTransportSocketLateBinding checks that a request which has started
// dialing can still be served on an existing connection that becomes idle in
// the meantime. It is not bound to the connection it began dialing.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate keeps the /foo handler open, after it has flushed its headers,
	// until the test lets it finish.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate admits one real dial per value sent on it; closing it makes a
	// blocked Dial fail. While a Dial is waiting on dialGate, it keeps
	// offering true on dialing, which announces that a dial is pending.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Allow exactly one real dial, which /foo uses.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// /foo's headers have arrived, so its connection exists and its
		// handler is parked on fooGate. Wait until the /bar request shows
		// up in Dial, then let /foo finish so that its connection is
		// released and can be used by /bar.
		if mode == http2Mode {
			// Presumably HTTP/2 multiplexes /bar onto the existing
			// connection, so a second Dial is treated as an error. Give it
			// 10ms to appear; either way, continue and release /foo.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	// /bar must be served from the same remote address as /foo: the
	// connection /foo released, not the still-blocked second dial.
	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
2881
2882
// TestTransportReading100Continue sends several POST requests over a fake,
// pipe-backed connection. For each request, the fake server writes a
// "100 Continue" interim response immediately followed by a final 200 OK.
// The test checks that the client skips the interim response and returns the
// final one.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	// reqBody and reqID give the expected body and Request-Id of the n-th
	// request (1-based).
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response is the fake server. It reads requests from r and writes
	// responses to w. It stops on EOF, on a read error, or after answering
	// the request whose ID is reqID(numReqs).
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			// A request may ask for a different leading status line via
			// X-Want-Response-Code. Only the default "100 Continue" path
			// checks the request body.
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			// The interim response comes first, then the final response,
			// which echoes the request ID. The template is written with "\n"
			// and converted to CRLF line endings before it is sent.
			body := fmt.Sprintf("Response number %d", n)
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	// Each Dial builds a fresh pair of pipes:
	//   - the client writes to sw, and the fake server reads from sr;
	//   - the fake server writes to cw, and the client reads from cr.
	// Closing the conn closes both write ends.
	tr := &Transport{
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe()
			cr, cw := io.Pipe()
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse sends req and checks three things:
	//   - the status code is wantCode;
	//   - Echo-Request-Id matches Request-Id, when one is set;
	//   - the body can be read to the end.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Send numReqs POSTs in sequence, each expecting the final 200. With
	// keep-alives on, they presumably share one connection, so the server's
	// counter n matches i. NOTE(review): confirm.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
2982
2983
2984
2985 func TestTransportIgnore1xxResponses(t *testing.T) {
2986 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
2987 }
2988 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
2989 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2990 conn, buf, _ := w.(Hijacker).Hijack()
2991 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
2992 buf.Flush()
2993 conn.Close()
2994 }))
2995 cst.tr.DisableKeepAlives = true
2996
2997 var got strings.Builder
2998
2999 req, _ := NewRequest("GET", cst.ts.URL, nil)
3000 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3001 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3002 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3003 return nil
3004 },
3005 }))
3006 res, err := cst.c.Do(req)
3007 if err != nil {
3008 t.Fatal(err)
3009 }
3010 defer res.Body.Close()
3011
3012 res.Write(&got)
3013 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3014 if got.String() != want {
3015 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3016 }
3017 }
3018
3019 func TestTransportLimits1xxResponses(t *testing.T) {
3020 run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
3021 }
3022 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3023 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3024 conn, buf, _ := w.(Hijacker).Hijack()
3025 for i := 0; i < 10; i++ {
3026 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
3027 }
3028 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3029 buf.Flush()
3030 conn.Close()
3031 }))
3032 cst.tr.DisableKeepAlives = true
3033
3034 res, err := cst.c.Get(cst.ts.URL)
3035 if res != nil {
3036 defer res.Body.Close()
3037 }
3038 got := fmt.Sprint(err)
3039 wantSub := "too many 1xx informational responses"
3040 if !strings.Contains(got, wantSub) {
3041 t.Errorf("Get error = %v; want substring %q", err, wantSub)
3042 }
3043 }
3044
3045
3046
3047 func TestTransportTreat101Terminal(t *testing.T) {
3048 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3049 }
3050 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3051 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3052 conn, buf, _ := w.(Hijacker).Hijack()
3053 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3054 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3055 buf.Flush()
3056 conn.Close()
3057 }))
3058 res, err := cst.c.Get(cst.ts.URL)
3059 if err != nil {
3060 t.Fatal(err)
3061 }
3062 defer res.Body.Close()
3063 if res.StatusCode != StatusSwitchingProtocols {
3064 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3065 }
3066 }
3067
// proxyFromEnvTest describes one proxy-selection scenario: the request URL,
// the proxy-related environment values to set, and the expected outcome.
type proxyFromEnvTest struct {
	req string // URL to request; "" means "http://example.com"

	env      string // value for HTTP_PROXY / http_proxy
	httpsenv string // value for HTTPS_PROXY / https_proxy
	noenv    string // value for NO_PROXY / no_proxy
	reqmeth  string // value for REQUEST_METHOD

	want    string // expected proxy URL, formatted with %s
	wanterr error  // expected error, compared by its %v text
}

// String renders the scenario for failure messages. It prints the non-empty
// settings as space-separated key=%q pairs and always ends with the
// effective request URL.
func (t proxyFromEnvTest) String() string {
	var parts []string
	if t.env != "" {
		parts = append(parts, fmt.Sprintf("http_proxy=%q", t.env))
	}
	if t.httpsenv != "" {
		parts = append(parts, fmt.Sprintf("https_proxy=%q", t.httpsenv))
	}
	if t.noenv != "" {
		parts = append(parts, fmt.Sprintf("no_proxy=%q", t.noenv))
	}
	if t.reqmeth != "" {
		parts = append(parts, fmt.Sprintf("request_method=%q", t.reqmeth))
	}
	target := t.req
	if target == "" {
		target = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", target))
	return strings.TrimSpace(strings.Join(parts, " "))
}
3110
// proxyFromEnvTests is the table shared by TestProxyFromEnvironment (upper-case
// variable names) and TestProxyFromEnvironmentLowerCase (lower-case names).
// A want of "<nil>" means no proxy is expected.
var proxyFromEnvTests = []proxyFromEnvTest{
	// A proxy value without a scheme gets http://; an explicit scheme is kept.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},

	// An http:// request uses the HTTP proxy, even when an HTTPS proxy is
	// also set.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https:// request uses the HTTPS proxy. A scheme-less HTTPS proxy
	// value still gets http://.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// When REQUEST_METHOD is set (a CGI environment), the HTTP_PROXY value
	// is refused with an error rather than used.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy variables set means no proxy.
	{want: "<nil>"},

	// NO_PROXY matching:
	//   - a plain entry excludes the host and its subdomains;
	//   - a leading-dot entry excludes only subdomains;
	//   - a bare suffix that does not fall on a label boundary ("ample.com")
	//     does not match.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3140
3141 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3142 t.Helper()
3143 reqURL := tt.req
3144 if reqURL == "" {
3145 reqURL = "http://example.com"
3146 }
3147 req, _ := NewRequest("GET", reqURL, nil)
3148 url, err := proxyForRequest(req)
3149 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3150 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3151 return
3152 }
3153 if got := fmt.Sprintf("%s", url); got != tt.want {
3154 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3155 }
3156 }
3157
3158 func TestProxyFromEnvironment(t *testing.T) {
3159 ResetProxyEnv()
3160 defer ResetProxyEnv()
3161 for _, tt := range proxyFromEnvTests {
3162 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3163 os.Setenv("HTTP_PROXY", tt.env)
3164 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3165 os.Setenv("NO_PROXY", tt.noenv)
3166 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3167 ResetCachedEnvironment()
3168 return ProxyFromEnvironment(req)
3169 })
3170 }
3171 }
3172
3173 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3174 ResetProxyEnv()
3175 defer ResetProxyEnv()
3176 for _, tt := range proxyFromEnvTests {
3177 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3178 os.Setenv("http_proxy", tt.env)
3179 os.Setenv("https_proxy", tt.httpsenv)
3180 os.Setenv("no_proxy", tt.noenv)
3181 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3182 ResetCachedEnvironment()
3183 return ProxyFromEnvironment(req)
3184 })
3185 }
3186 }
3187
// TestIdleConnChannelLeak sends a series of requests to distinct hosts, first
// without and then with keep-alives. After each series, the Transport's
// idle-connection wait map must be empty.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// Not parallel: this test installs a global read-loop hook below.
	var mu sync.Mutex
	var n int

	// The handler writes no body; it only counts requests in n.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// The hook sends on didRead each time a read loop is about to read its
	// next response. The channel is buffered to nReqs.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Route every fake host to the test server. Each request uses a
	// different host name, but all of them reach the same listener.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// First run without keep-alives, then with them.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// The response body is not closed here, and the handler writes
			// no body. NOTE(review): this presumably relies on an empty
			// response releasing its connection without an explicit Close;
			// confirm.
		}

		// Wait until every request's read loop has reached the hook, so that
		// (presumably) the post-response connection bookkeeping has
		// finished before the wait map is inspected.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3243
3244
3245
3246
3247 func TestTransportClosesRequestBody(t *testing.T) {
3248 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3249 }
3250 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3251 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3252 io.Copy(io.Discard, r.Body)
3253 })).ts
3254
3255 c := ts.Client()
3256
3257 closes := 0
3258
3259 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3260 if err != nil {
3261 t.Fatal(err)
3262 }
3263 res.Body.Close()
3264 if closes != 1 {
3265 t.Errorf("closes = %d; want 1", closes)
3266 }
3267 }
3268
3269 func TestTransportTLSHandshakeTimeout(t *testing.T) {
3270 defer afterTest(t)
3271 if testing.Short() {
3272 t.Skip("skipping in short mode")
3273 }
3274 ln := newLocalListener(t)
3275 defer ln.Close()
3276 testdonec := make(chan struct{})
3277 defer close(testdonec)
3278
3279 go func() {
3280 c, err := ln.Accept()
3281 if err != nil {
3282 t.Error(err)
3283 return
3284 }
3285 <-testdonec
3286 c.Close()
3287 }()
3288
3289 tr := &Transport{
3290 Dial: func(_, _ string) (net.Conn, error) {
3291 return net.Dial("tcp", ln.Addr().String())
3292 },
3293 TLSHandshakeTimeout: 250 * time.Millisecond,
3294 }
3295 cl := &Client{Transport: tr}
3296 _, err := cl.Get("https://dummy.tld/")
3297 if err == nil {
3298 t.Error("expected error")
3299 return
3300 }
3301 ue, ok := err.(*url.Error)
3302 if !ok {
3303 t.Errorf("expected url.Error; got %#v", err)
3304 return
3305 }
3306 ne, ok := ue.Err.(net.Error)
3307 if !ok {
3308 t.Errorf("expected net.Error; got %#v", err)
3309 return
3310 }
3311 if !ne.Timeout() {
3312 t.Errorf("expected timeout error; got %v", err)
3313 }
3314 if !strings.Contains(err.Error(), "handshake timeout") {
3315 t.Errorf("expected 'handshake timeout' in error; got %v", err)
3316 }
3317 }
3318
3319
3320 func TestTLSServerClosesConnection(t *testing.T) {
3321 run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3322 }
3323 func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3324 closedc := make(chan bool, 1)
3325 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3326 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3327 conn, _, _ := w.(Hijacker).Hijack()
3328 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3329 conn.Close()
3330 closedc <- true
3331 return
3332 }
3333 fmt.Fprintf(w, "hello")
3334 })).ts
3335
3336 c := ts.Client()
3337 tr := c.Transport.(*Transport)
3338
3339 var nSuccess = 0
3340 var errs []error
3341 const trials = 20
3342 for i := 0; i < trials; i++ {
3343 tr.CloseIdleConnections()
3344 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3345 if err != nil {
3346 t.Fatal(err)
3347 }
3348 <-closedc
3349 slurp, err := io.ReadAll(res.Body)
3350 if err != nil {
3351 t.Fatal(err)
3352 }
3353 if string(slurp) != "foo" {
3354 t.Errorf("Got %q, want foo", slurp)
3355 }
3356
3357
3358
3359 res, err = c.Get(ts.URL + "/")
3360 if err != nil {
3361 errs = append(errs, err)
3362 continue
3363 }
3364 slurp, err = io.ReadAll(res.Body)
3365 if err != nil {
3366 errs = append(errs, err)
3367 continue
3368 }
3369 nSuccess++
3370 }
3371 if nSuccess > 0 {
3372 t.Logf("successes = %d of %d", nSuccess, trials)
3373 } else {
3374 t.Errorf("All runs failed:")
3375 }
3376 for _, err := range errs {
3377 t.Logf(" err: %v", err)
3378 }
3379 }
3380
3381
3382
3383
// byteFromChanReader is an io.Reader that returns at most one byte per Read,
// taken from the channel. Read blocks until a byte is available and reports
// io.EOF once the channel is closed and drained. A zero-length Read returns
// immediately without consuming anything.
type byteFromChanReader chan byte

func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		b, ok := <-c
		if !ok {
			return 0, io.EOF
		}
		p[0] = b
		n = 1
	}
	return n, err
}
3398
3399
3400
3401
3402
3403
3404 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
3405 run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
3406 }
3407 func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
3408 defer func(d time.Duration) {
3409 *MaxWriteWaitBeforeConnReuse = d
3410 }(*MaxWriteWaitBeforeConnReuse)
3411 *MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
3412 var sconn struct {
3413 sync.Mutex
3414 c net.Conn
3415 }
3416 var getOkay bool
3417 closeConn := func() {
3418 sconn.Lock()
3419 defer sconn.Unlock()
3420 if sconn.c != nil {
3421 sconn.c.Close()
3422 sconn.c = nil
3423 if !getOkay {
3424 t.Logf("Closed server connection")
3425 }
3426 }
3427 }
3428 defer closeConn()
3429
3430 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3431 if r.Method == "GET" {
3432 io.WriteString(w, "bar")
3433 return
3434 }
3435 conn, _, _ := w.(Hijacker).Hijack()
3436 sconn.Lock()
3437 sconn.c = conn
3438 sconn.Unlock()
3439 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3440 go io.Copy(io.Discard, conn)
3441 })).ts
3442 c := ts.Client()
3443
3444 const bodySize = 256 << 10
3445 finalBit := make(byteFromChanReader, 1)
3446 req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
3447 req.ContentLength = bodySize
3448 res, err := c.Do(req)
3449 if err := wantBody(res, err, "foo"); err != nil {
3450 t.Errorf("POST response: %v", err)
3451 }
3452
3453 res, err = c.Get(ts.URL)
3454 if err := wantBody(res, err, "bar"); err != nil {
3455 t.Errorf("GET response: %v", err)
3456 return
3457 }
3458 getOkay = true
3459 finalBit <- 'x'
3460 close(finalBit)
3461 }
3462
3463
3464
3465 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3466 func testTransportIssue10457(t *testing.T, mode testMode) {
3467 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3468
3469
3470
3471
3472
3473 conn, _, _ := w.(Hijacker).Hijack()
3474 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3475 conn.Close()
3476 })).ts
3477 c := ts.Client()
3478
3479 res, err := c.Get(ts.URL)
3480 if err != nil {
3481 t.Fatalf("Get: %v", err)
3482 }
3483 defer res.Body.Close()
3484
3485
3486
3487
3488 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3489 t.Errorf("Foo header = %q; want %q", got, want)
3490 }
3491 }
3492
// closerFunc lets an ordinary function serve as an io.Closer.
type closerFunc func() error

// Close calls f and returns its result.
func (f closerFunc) Close() error {
	return f()
}
3496
// writerFuncConn is a net.Conn whose Write calls the write function instead.
// All other methods go to the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

func (wc writerFuncConn) Write(p []byte) (n int, err error) {
	return wc.write(p)
}
3503
3504
3505
3506
3507
3508
3509
3510
3511
3512
3513
3514
3515
// TestRetryRequestsOnError checks that the Transport retries a request on a
// new connection when writing it to a reused connection fails in the ways
// described by the test cases.
func TestRetryRequestsOnError(t *testing.T) {
	run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
}
func testRetryRequestsOnError(t *testing.T, mode testMode) {
	// newRequest is NewRequest, except that it fails the test on error.
	newRequest := func(method, urlStr string, body io.Reader) *Request {
		req, err := NewRequest(method, urlStr, body)
		if err != nil {
			t.Fatal(err)
		}
		return req
	}

	testCases := []struct {
		name       string
		failureN   int   // byte count returned by the failing (second) Write
		failureErr error // error returned by the failing (second) Write

		// req builds a fresh Request for each Do call; presumably a request
		// body is consumed when sent and cannot be shared across calls.
		req       func() *Request
		reqString string // expected %q-escaped wire form of req, as logged by Write
	}{
		{
			name: "IdempotentNoBodySomeWritten",
			// Report that some bytes were written before the failure, so
			// the retry cannot rely on "nothing was sent".
			failureN: 1,
			// ExportErrServerClosedIdle is presumably the error the
			// Transport treats as retryable for idempotent requests.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", nil)
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "IdempotentGetBodySomeWritten",
			// Same as above, but the GET carries a body.
			failureN: 1,
			// Same retryable error as above.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
		{
			name: "NothingWrittenNoBody",
			// Report that zero bytes were written. The retry presumably
			// relies on nothing having been sent, not on the method.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			req: func() *Request {
				return newRequest("DELETE", "http://fake.golang", nil)
			},
			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "NothingWrittenGetBody",
			// Zero bytes written, as above, but for a POST with a body.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			// NewRequest sets GetBody for a *strings.Reader, which presumably
			// lets the body be rewound for the retry.
			req: func() *Request {
				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// logbuf records events from the dialer, the handler and the
			// retry hook; mu guards it across goroutines.
			var (
				mu     sync.Mutex
				logbuf strings.Builder
			)
			logf := func(format string, args ...any) {
				mu.Lock()
				defer mu.Unlock()
				fmt.Fprintf(&logbuf, format, args...)
				logbuf.WriteByte('\n')
			}

			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
				logf("Handler")
				w.Header().Set("X-Status", "ok")
			})).ts

			// writeNumAtomic counts Writes across all dialed connections.
			// Only the second Write overall fails, with tc.failureN and
			// tc.failureErr.
			var writeNumAtomic int32
			c := ts.Client()
			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
				logf("Dial")
				c, err := net.Dial(network, ts.Listener.Addr().String())
				if err != nil {
					logf("Dial error: %v", err)
					return nil, err
				}
				return &writerFuncConn{
					Conn: c,
					write: func(p []byte) (n int, err error) {
						if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
							logf("intentional write failure")
							return tc.failureN, tc.failureErr
						}
						logf("Write(%q)", p)
						return c.Write(p)
					},
				}, nil
			}

			SetRoundTripRetried(func() {
				logf("Retried.")
			})
			defer SetRoundTripRetried(nil)

			// Three requests: the first dials and succeeds; the second hits
			// the failing write and must be retried; the third reuses the
			// retry's connection.
			for i := 0; i < 3; i++ {
				t0 := time.Now()
				req := tc.req()
				res, err := c.Do(req)
				if err != nil {
					// A quick failure is a real test failure. A slow one
					// suggests timing interfered with connection reuse, so
					// the test is skipped.
					if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
						mu.Lock()
						got := logbuf.String()
						mu.Unlock()
						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
					}
					t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
				}
				res.Body.Close()
				if res.Request != req {
					t.Errorf("Response.Request != original request; want identical Request")
				}
			}

			mu.Lock()
			got := logbuf.String()
			mu.Unlock()
			want := fmt.Sprintf(`Dial
Write("%s")
Handler
intentional write failure
Retried.
Dial
Write("%s")
Handler
Write("%s")
Handler
`, tc.reqString, tc.reqString, tc.reqString)
			if got != want {
				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
			}
		})
	}
}
3672
3673
// TestTransportClosesBodyOnError checks that when the request body fails
// partway through (after 1 MiB of data), Client.Do reports that error, the
// handler observes a failed body read, and the body's Close method is called.
func TestTransportClosesBodyOnError(t *testing.T) {
	run(t, testTransportClosesBodyOnError)
}

func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
	handlerErr := make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, readErr := io.ReadAll(r.Body)
		handlerErr <- readErr
	})).ts
	client := ts.Client()

	bodyErr := errors.New("fake error")
	sawClose := make(chan bool, 1)

	// 1 MiB of 'x', then a read that fails with bodyErr.
	payload := io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(bodyErr))
	onClose := closerFunc(func() error {
		// Non-blocking send: the first Close is recorded, later ones are dropped.
		select {
		case sawClose <- true:
		default:
		}
		return nil
	})
	req, _ := NewRequest("POST", ts.URL, struct {
		io.Reader
		io.Closer
	}{payload, onClose})

	res, err := client.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if err == nil || !strings.Contains(err.Error(), bodyErr.Error()) {
		t.Fatalf("Do error = %v; want something containing %q", err, bodyErr.Error())
	}
	if readErr := <-handlerErr; readErr == nil {
		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
	}
	select {
	case <-sawClose:
	default:
		t.Errorf("didn't see Body.Close")
	}
}
3713
// TestTransportDialTLS checks that a Transport's DialTLS hook is used to
// establish the connection and that the request still reaches the handler.
func TestTransportDialTLS(t *testing.T) {
	run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
}
func testTransportDialTLS(t *testing.T, mode testMode) {
	var mu sync.Mutex // guards gotReq and didDial
	var gotReq, didDial bool

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
		mu.Lock()
		didDial = true
		mu.Unlock()
		// Named tc (not c) so it does not shadow the *Client whose
		// TLSClientConfig is read here.
		tc, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
		if err != nil {
			return nil, err
		}
		// tls.Dial already performs the handshake, so on success this
		// explicit call should be a no-op.
		return tc, tc.Handshake()
	}

	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	// Release the lock on return so the mutex is never left held, e.g. if a
	// hook runs again during cleanup.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if !didDial {
		t.Error("didn't use dial hook")
	}
}
3751
// TestTransportDialContext checks that the Transport's DialContext hook is
// invoked with the very context attached to the request via WithContext.
func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
func testTransportDialContext(t *testing.T, mode testMode) {
	// ctxKey is a private key type, as context.WithValue recommends, so the
	// test value cannot collide with keys set by any other package.
	type ctxKey string

	var mu sync.Mutex // guards gotReq and receivedContext
	var gotReq bool
	var receivedContext context.Context

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		receivedContext = ctx
		mu.Unlock()
		return net.Dial(netw, addr)
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.WithValue(context.Background(), ctxKey("some-key"), "some-value")
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	// Release the lock on return so the mutex is never left held.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if receivedContext != ctx {
		t.Error("didn't receive correct context")
	}
}
3789
// TestTransportDialTLSContext checks that the Transport's DialTLSContext
// hook is invoked with the very context attached to the request.
func TestTransportDialTLSContext(t *testing.T) {
	run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
}
func testTransportDialTLSContext(t *testing.T, mode testMode) {
	// ctxKey is a private key type, as context.WithValue recommends, so the
	// test value cannot collide with keys set by any other package.
	type ctxKey string

	var mu sync.Mutex // guards gotReq and receivedContext
	var gotReq bool
	var receivedContext context.Context

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		receivedContext = ctx
		mu.Unlock()
		// Named tc (not c) so it does not shadow the *Client whose
		// TLSClientConfig is read here.
		tc, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
		if err != nil {
			return nil, err
		}
		return tc, tc.HandshakeContext(ctx)
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.WithValue(context.Background(), ctxKey("some-key"), "some-value")
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	// Release the lock on return so the mutex is never left held.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if receivedContext != ctx {
		t.Error("didn't receive correct context")
	}
}
3833
3834
3835
// TestRoundTripReturnsProxyError checks that an error returned by the
// Transport's Proxy function is propagated out of RoundTrip.
func TestRoundTripReturnsProxyError(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) {
			return nil, errors.New("errorMessage")
		},
	}
	req, _ := NewRequest("GET", "http://example.com", nil)
	if _, err := tr.RoundTrip(req); err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	}
}
3851
3852
// TestTransportCloseIdleConnsThenReturn checks idle-pool behavior around
// CloseIdleConnections. Puts succeed before it and are rejected after it.
// They succeed again once QueueForIdleConnForTesting has made the Transport
// non-idle.
func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
	tr := &Transport{}
	// wantIdle reports whether tr has exactly n idle conns for
	// http://example.com. If not, it records a test error labeled with when.
	wantIdle := func(when string, n int) bool {
		got := tr.IdleConnCountForTesting("http", "example.com")
		if got == n {
			return true
		}
		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
		return false
	}
	wantIdle("start", 0)
	if !tr.PutIdleTestConn("http", "example.com") {
		t.Fatal("put failed")
	}
	if !tr.PutIdleTestConn("http", "example.com") {
		t.Fatal("second put failed")
	}
	wantIdle("after put", 2)
	tr.CloseIdleConnections()
	if !tr.IsIdleForTesting() {
		t.Error("should be idle after CloseIdleConnections")
	}
	wantIdle("after close idle", 0)
	// While the Transport is idle, a put is expected to be rejected.
	if tr.PutIdleTestConn("http", "example.com") {
		t.Fatal("put didn't fail")
	}
	wantIdle("after second put", 0)

	// Queueing for an idle conn re-activates the Transport. After that, a
	// put is expected to succeed again.
	tr.QueueForIdleConnForTesting()
	if tr.IsIdleForTesting() {
		t.Error("shouldn't be idle after QueueForIdleConnForTesting")
	}
	if !tr.PutIdleTestConn("http", "example.com") {
		t.Fatal("after re-activation")
	}
	wantIdle("after final put", 1)
}
3890
3891
3892
3893 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
3894 tr := &Transport{}
3895 wantIdle := func(when string, n int) bool {
3896 got := tr.IdleConnCountForTesting("https", "example.com:443")
3897 if got == n {
3898 return true
3899 }
3900 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
3901 return false
3902 }
3903 wantIdle("start", 0)
3904 alt := funcRoundTripper(func() {})
3905 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
3906 t.Fatal("put failed")
3907 }
3908 wantIdle("after put", 1)
3909 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3910 GotConn: func(httptrace.GotConnInfo) {
3911
3912 t.Error("GotConn called")
3913 },
3914 })
3915 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
3916 _, err := tr.RoundTrip(req)
3917 if err != errFakeRoundTrip {
3918 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
3919 }
3920 wantIdle("after round trip", 1)
3921 }
3922
// TestTransportRemovesH2ConnsAfterIdle issues two HTTP/2 GETs separated by
// ten idle timeouts, against a Transport limited to one conn per host, and
// requires both to succeed. An error mentioning "use of closed network
// connection" is taken to mean the timeout is too short for this machine.
// In that case the test restarts with the timeout doubled.
func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
	run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
}
func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Millisecond
	retry := true
	for retry {
		trFunc := func(tr *Transport) {
			tr.MaxConnsPerHost = 1
			tr.MaxIdleConnsPerHost = 1
			tr.IdleConnTimeout = timeout
		}
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)

		retry = false
		// tooShort reports whether err contains "use of closed network
		// connection". The first time it does in this attempt, it also
		// doubles timeout, closes cst, and schedules another attempt.
		tooShort := func(err error) bool {
			if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
				return false
			}
			if !retry {
				t.Helper()
				t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
				timeout *= 2
				retry = true
				cst.close()
			}
			return true
		}

		// get issues one GET. On success it closes the response body so the
		// response is not leaked.
		get := func() error {
			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				return err
			}
			res.Body.Close()
			return nil
		}

		if err := get(); err != nil {
			if tooShort(err) {
				continue
			}
			t.Fatalf("got error: %s", err)
		}

		time.Sleep(10 * timeout)
		if err := get(); err != nil {
			if tooShort(err) {
				continue
			}
			t.Fatalf("got error: %s", err)
		}
	}
}
3972
3973
3974
3975
3976
// TestTransportRangeAndGzip checks that when the caller sets a Range header,
// the request reaching the server carries that Range and does not advertise
// gzip in Accept-Encoding.
func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
func testTransportRangeAndGzip(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		if ae := r.Header.Get("Accept-Encoding"); strings.Contains(ae, "gzip") {
			t.Error("Transport advertised gzip support in the Accept header")
		}
		if len(r.Header.Get("Range")) == 0 {
			t.Error("no Range in request")
		}
	})
	ts := newClientServerTest(t, mode, handler).ts

	req, _ := NewRequest("GET", ts.URL, nil)
	req.Header.Set("Range", "bytes=7-11")
	res, err := ts.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
3997
3998
// TestTransportResponseCancelRace completes one request and then calls
// CancelRequest on it after the fact. It expects a second request on the
// same Transport to still succeed.
func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
func testTransportResponseCancelRace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Respond with a 1024-byte body of zero bytes.
		var b [1024]byte
		w.Write(b[:])
	})).ts
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}

	// Drain the first body to EOF instead of closing it early.
	// NOTE(review): presumably so the connection goes back to the idle pool
	// and is reused by req2, which is what exposes the cancel race. Confirm
	// before replacing this with a Close.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		t.Fatal(err)
	}

	req2, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Cancel the already-finished first request, then expect req2 to succeed.
	tr.CancelRequest(req)
	res, err = tr.RoundTrip(req2)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4034
4035
// TestTransportContentEncodingCaseInsensitive serves a gzip-compressed body
// labeled with Content-Encoding "gzip" and with "GZIP". In both cases it
// expects the client to see the original, uncompressed text.
func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
	run(t, testTransportContentEncodingCaseInsensitive)
}
func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
	const encodedString = "Hello Gopher"
	encodings := []string{"gzip", "GZIP"}
	for i := range encodings {
		ce := encodings[i]
		t.Run(ce, func(t *testing.T) {
			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
				w.Header().Set("Content-Encoding", ce)
				zw := gzip.NewWriter(w)
				zw.Write([]byte(encodedString))
				zw.Close()
			})).ts

			res, err := ts.Client().Get(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Fatal(err)
			}
			if got := string(body); got != encodedString {
				t.Fatalf("Expected body %q, got: %q\n", encodedString, got)
			}
		})
	}
}
4068
// TestTransportDialCancelRace calls Transport.CancelRequest from the hook
// installed with SetEnterRoundTripHook. It expects RoundTrip to return
// ExportErrRequestCanceled.
// NOTE(review): the hook is presumably invoked on entry to RoundTrip,
// before any dial. The test is non-parallel, presumably because the hook is
// shared global state; confirm.
func TestTransportDialCancelRace(t *testing.T) {
	run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode})
}
func testTransportDialCancelRace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	SetEnterRoundTripHook(func() {
		tr.CancelRequest(req)
	})
	defer SetEnterRoundTripHook(nil)
	res, err := tr.RoundTrip(req)
	if err != ExportErrRequestCanceled {
		t.Errorf("expected canceled request error; got %v", err)
		// Unexpected success: still release the response body.
		if err == nil {
			res.Body.Close()
		}
	}
}
4092
4093
// TestConnClosedBeforeRequestIsWritten dials only fake conns whose read and
// write funcs always return an error. It checks that the POST fails and that
// the request body is closed exactly once.
func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
	run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
}
func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
		func(tr *Transport) {
			tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
				// Every dial returns a conn whose read and write funcs fail.
				return &funcConn{
					read: func([]byte) (int, error) {
						return 0, errors.New("error")
					},
					write: func([]byte) (int, error) {
						return 0, errors.New("error")
					},
				}, nil
			}
		},
	).ts

	// Sleep 1ms in the hook installed via SetEnterRoundTripHook.
	// NOTE(review): presumably this widens the window in which the failing
	// conn is observed before the request is written; confirm.
	SetEnterRoundTripHook(func() {
		time.Sleep(1 * time.Millisecond)
	})
	defer SetEnterRoundTripHook(nil)
	// closes is passed to countCloseReader, presumably incremented on each
	// Close of the request body.
	var closes int
	_, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
	if err == nil {
		t.Fatalf("expected request to fail, but it did not")
	}
	if closes != 1 {
		t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
	}
}
4131
4132
4133
4134
4135 type logWritesConn struct {
4136 net.Conn
4137
4138 w io.Writer
4139
4140 rch <-chan io.Reader
4141 r io.Reader
4142
4143 mu sync.Mutex
4144 writes []string
4145 }
4146
4147 func (c *logWritesConn) Write(p []byte) (n int, err error) {
4148 c.mu.Lock()
4149 defer c.mu.Unlock()
4150 c.writes = append(c.writes, string(p))
4151 return c.w.Write(p)
4152 }
4153
4154 func (c *logWritesConn) Read(p []byte) (n int, err error) {
4155 if c.r == nil {
4156 c.r = <-c.rch
4157 }
4158 return c.r.Read(p)
4159 }
4160
4161 func (c *logWritesConn) Close() error { return nil }
4162
4163
// TestTransportFlushesBodyChunks sends a chunked POST body over a
// logWritesConn. It checks that the conn sees exactly one Write for the
// headers, one per body chunk, and one for the terminating chunk.
func TestTransportFlushesBodyChunks(t *testing.T) {
	defer afterTest(t)
	resBody := make(chan io.Reader, 1)
	// The client writes its request into connw; the test parses it from connr.
	connr, connw := io.Pipe()
	lw := &logWritesConn{
		rch: resBody,
		w:   connw,
	}
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			return lw, nil
		},
	}
	bodyr, bodyw := io.Pipe()
	go func() {
		defer bodyw.Close()
		for i := 0; i < 3; i++ {
			fmt.Fprintf(bodyw, "num%d\n", i)
		}
	}()
	resc := make(chan *Response)
	go func() {
		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
		req.Header.Set("User-Agent", "x")
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("RoundTrip: %v", err)
			close(resc)
			return
		}
		resc <- res
	}()

	req, err := ReadRequest(bufio.NewReader(connr))
	if err != nil {
		t.Fatal(err)
	}
	io.Copy(io.Discard, req.Body)

	// Supply the response only after the request body has been fully read.
	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
	res, ok := <-resc
	if !ok {
		return
	}
	defer res.Body.Close()

	want := []string{
		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
		"5\r\nnum0\n\r\n",
		"5\r\nnum1\n\r\n",
		"5\r\nnum2\n\r\n",
		"0\r\n\r\n",
	}
	// lw.writes is guarded by lw.mu (logWritesConn.Write appends under it),
	// so take a snapshot under the same lock rather than reading it bare.
	lw.mu.Lock()
	got := append([]string(nil), lw.writes...)
	lw.mu.Unlock()
	if !reflect.DeepEqual(got, want) {
		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", got, want)
	}
}
4223
4224
// TestTransportFlushesRequestHeader checks that the request reaches the
// server's handler while the request body is still open. The body pipe is
// closed only after the handler has run, so a Transport that withheld the
// headers until the body finished would deadlock here.
func TestTransportFlushesRequestHeader(t *testing.T) {
	run(t, testTransportFlushesRequestHeader)
}
func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
	handlerRan := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		close(handlerRan)
	}))

	bodyR, bodyW := io.Pipe()
	req, err := NewRequest("POST", cst.ts.URL, bodyR)
	if err != nil {
		t.Fatal(err)
	}

	roundTripDone := make(chan struct{})
	go func() {
		defer close(roundTripDone)
		res, err := cst.tr.RoundTrip(req)
		if err == nil {
			res.Body.Close()
			return
		}
		t.Error(err)
	}()

	<-handlerRan
	bodyW.Close()
	<-roundTripDone
}
4252
4253 type wgReadCloser struct {
4254 io.Reader
4255 wg *sync.WaitGroup
4256 closed bool
4257 }
4258
4259 func (c *wgReadCloser) Close() error {
4260 if c.closed {
4261 return net.ErrClosed
4262 }
4263 c.closed = true
4264 c.wg.Done()
4265 return nil
4266 }
4267
4268
// TestTransportPrefersResponseOverWriteError sends 100 PUTs whose bodies are
// twice the server's limit. The handler replies 400 and closes the body
// without reading it. The test counts any Client.Do error as a failure: it
// expects the 400 response to be returned rather than an error
// (NOTE(review): presumably a write error, per the test name).
func TestTransportPrefersResponseOverWriteError(t *testing.T) {
	run(t, testTransportPrefersResponseOverWriteError)
}
func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	const contentLengthLimit = 1024 * 1024
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.ContentLength < contentLengthLimit {
			w.WriteHeader(StatusOK)
			return
		}
		w.WriteHeader(StatusBadRequest)
		r.Body.Close()
	})).ts
	c := ts.Client()

	const count = 100
	bigBody := strings.Repeat("a", contentLengthLimit*2)

	// Every body handed out by getBody adds to wg, and its first Close calls
	// Done. The deferred Wait therefore blocks until all of them are closed.
	var wg sync.WaitGroup
	defer wg.Wait()
	getBody := func() (io.ReadCloser, error) {
		wg.Add(1)
		return &wgReadCloser{Reader: strings.NewReader(bigBody), wg: &wg}, nil
	}

	fail := 0
	for i := 0; i < count; i++ {
		reqBody, _ := getBody()
		req, err := NewRequest("PUT", ts.URL, reqBody)
		if err != nil {
			reqBody.Close()
			t.Fatal(err)
		}
		req.ContentLength = int64(len(bigBody))
		req.GetBody = getBody

		resp, err := c.Do(req)
		if err == nil {
			resp.Body.Close()
			if resp.StatusCode != StatusBadRequest {
				t.Errorf("Expected status code 400, got %v", resp.Status)
			}
			continue
		}

		fail++
		t.Logf("%d = %#v", i, err)
		ue, ok := err.(*url.Error)
		if !ok {
			continue
		}
		t.Logf("urlErr = %#v", ue.Err)
		if ne, ok := ue.Err.(*net.OpError); ok {
			t.Logf("netOpError = %#v", ne.Err)
		}
	}
	if fail > 0 {
		t.Errorf("Failed %v out of %v\n", fail, count)
	}
}
4333
// TestTransportAutomaticHTTP2 expects a zero-value Transport to end up with
// an "h2" TLSNextProto entry.
func TestTransportAutomaticHTTP2(t *testing.T) {
	testTransportAutoHTTP(t, &Transport{}, true)
}

// TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig
// expects ForceAttemptHTTP2 to yield h2 even with a custom TLSClientConfig.
// Compare TestTransportAutomaticHTTP2_TLSConfig, which expects no h2
// without it.
func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
	testTransportAutoHTTP(t, &Transport{
		ForceAttemptHTTP2: true,
		TLSClientConfig:   new(tls.Config),
	}, true)
}

// TestTransportAutomaticHTTP2_DefaultTransport expects DefaultTransport to
// have h2. Note that it operates on the shared DefaultTransport itself, not
// on a copy.
func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
	testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
}

// TestTransportAutomaticHTTP2_TLSNextProto expects a caller-provided (empty)
// TLSNextProto map to prevent h2 registration.
func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
	testTransportAutoHTTP(t, &Transport{
		TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
	}, false)
}

// TestTransportAutomaticHTTP2_TLSConfig expects a custom TLSClientConfig
// alone (without ForceAttemptHTTP2) to prevent h2 registration.
func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
	testTransportAutoHTTP(t, &Transport{
		TLSClientConfig: new(tls.Config),
	}, false)
}

// TestTransportAutomaticHTTP2_ExpectContinueTimeout expects that setting
// ExpectContinueTimeout does not prevent h2 registration.
func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
	testTransportAutoHTTP(t, &Transport{
		ExpectContinueTimeout: 1 * time.Second,
	}, true)
}

// TestTransportAutomaticHTTP2_Dial expects a custom Dial func to prevent h2
// registration.
func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
	var d net.Dialer
	testTransportAutoHTTP(t, &Transport{
		Dial: d.Dial,
	}, false)
}

// TestTransportAutomaticHTTP2_DialContext expects a custom DialContext func
// to prevent h2 registration.
func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
	var d net.Dialer
	testTransportAutoHTTP(t, &Transport{
		DialContext: d.DialContext,
	}, false)
}

// TestTransportAutomaticHTTP2_DialTLS expects a custom DialTLS func to
// prevent h2 registration. The func panics if called, which asserts that it
// is never used by this test.
func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
	testTransportAutoHTTP(t, &Transport{
		DialTLS: func(network, addr string) (net.Conn, error) {
			panic("unused")
		},
	}, false)
}
4389
// testTransportAutoHTTP first calls CondSkipHTTP2 (presumably skipping when
// HTTP/2 is unavailable). It then calls tr.RoundTrip with an empty Request,
// which is expected to fail. Finally it checks whether tr.TLSNextProto
// holds a non-nil "h2" entry and compares that against wantH2.
func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
	CondSkipHTTP2(t)
	if _, err := tr.RoundTrip(new(Request)); err == nil {
		t.Error("expected error from RoundTrip")
	}
	registered := tr.TLSNextProto["h2"] != nil
	if registered != wantH2 {
		t.Errorf("HTTP/2 registered = %v; want %v", registered, wantH2)
	}
}
4400
4401
4402
4403
4404
4405
4406
4407
// TestTransportReuseConnEmptyResponseBody makes repeated GETs to a handler
// that sets a header but writes no body. Using the X-Addr header (the
// server-observed RemoteAddr), it checks that every request arrived over the
// same client connection.
func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
	run(t, testTransportReuseConnEmptyResponseBody)
}
func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
		// No body is written.
	}))
	n := 100
	if testing.Short() {
		n = 10
	}
	var firstAddr string
	for i := 0; i < n; i++ {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			// t.Fatal, not log.Fatal. log.Fatal calls os.Exit, which kills
			// the whole test binary and skips cleanups and other tests.
			t.Fatal(err)
		}
		addr := res.Header.Get("X-Addr")
		// Close before any Fatalf below so the body is never leaked.
		res.Body.Close()
		if i == 0 {
			firstAddr = addr
		} else if addr != firstAddr {
			t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
		}
	}
}
4435
4436
// TestNoCrashReturningTransportAltConn dials a TLS server that negotiates
// protocol "foo", for which the Transport has a TLSNextProto entry. The
// DialTLS hook completes the handshake, cancels the request, and then
// blocks until Client.Do has returned. The conn therefore reaches the
// Transport only after the request was canceled. The test expects three
// things:
//   - Do fails with errRequestCanceledConn.
//   - The "foo" constructor is still called.
//   - All pending dials finish.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	var wg sync.WaitGroup
	// NOTE(review): SetPendingDialHooks presumably runs the first func when
	// a dial starts and the second when it finishes. If so, wg.Wait below
	// waits for the late dial to be fully handled. Confirm.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	testDone := make(chan struct{})
	defer close(testDone)
	// Server side: accept one TLS conn offering only "foo", complete the
	// handshake, and hold the conn open until the test returns.
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// The URL host is not dialed: DialTLS below ignores its arguments and
	// always connects to addr.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				// The returned RoundTripper must never be asked to send a request.
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		// The plain (non-TLS) dialer must not be used for this https request.
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request, then hand back the conn only after Do returned.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Unblock DialTLS, then wait for the late conn to be processed.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4518
// TestTransportReuseConnection_Gzip_Chunked runs
// testTransportReuseConnection_Gzip with a handler that flushes before
// writing its gzip body.
func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportReuseConnection_Gzip(t, mode, true)
	})
}

// TestTransportReuseConnection_Gzip_ContentLength runs
// testTransportReuseConnection_Gzip without that early flush.
func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportReuseConnection_Gzip(t, mode, false)
	})
}

// testTransportReuseConnection_Gzip makes two GETs to a handler that serves
// rgz with "Content-Encoding: gzip". It fails if the server saw the two
// requests arrive from different remote addresses, that is, if the client
// connection was not reused. When chunked is true the handler flushes before
// writing. NOTE(review): presumably so the HTTP/1 response is sent chunked
// rather than with a Content-Length; confirm.
func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
	addr := make(chan string, 2) // server-observed RemoteAddr of each request
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addr <- r.RemoteAddr
		w.Header().Set("Content-Encoding", "gzip")
		if chunked {
			w.(Flusher).Flush()
		}
		w.Write(rgz)
	})).ts
	c := ts.Client()

	// These hooks only call t.Logf; they do not affect pass/fail.
	trace := &httptrace.ClientTrace{
		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
	}
	ctx := httptrace.WithClientTrace(context.Background(), trace)

	for i := 0; i < 2; i++ {
		req, _ := NewRequest("GET", ts.URL, nil)
		req = req.WithContext(ctx)
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		// Read exactly len(rgz) bytes of the response body. NOTE(review):
		// rgz is defined elsewhere; this presumes the body the client sees
		// is at least that long. Confirm.
		buf := make([]byte, len(rgz))
		if n, err := io.ReadFull(res.Body, buf); err != nil {
			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
		}
		// res.Body is not closed here. NOTE(review): this looks deliberate,
		// since the reuse check below then verifies that the conn is reused
		// without an explicit Close. Do not add a Close without confirming
		// the intent.
	}
	a1, a2 := <-addr, <-addr
	if a1 != a2 {
		t.Fatalf("didn't reuse connection")
	}
}
4573
// TestTransportResponseHeaderLength sets MaxResponseHeaderBytes to 512 KiB.
// It expects a plain response to succeed. It expects a response carrying a
// 1 MiB header value to fail with an error naming the 524288-byte limit.
// Skipped for HTTP/2.
func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.URL.Path != "/long" {
			return
		}
		w.Header().Set("Long", strings.Repeat("a", 1<<20))
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10

	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	res, err = c.Get(ts.URL + "/long")
	if err != nil {
		if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
			t.Errorf("got error: %v; want %q", err, want)
		}
		return
	}
	defer res.Body.Close()
	var headerBytes int64
	for name, values := range res.Header {
		for _, v := range values {
			headerBytes += int64(len(name)) + int64(len(v))
		}
	}
	t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, headerBytes)
}
4608
// TestTransportEventTrace sends a POST with "Expect: 100-continue" through an
// httptrace.ClientTrace that logs every hook, then checks the log for the
// expected events. A follow-up GET then checks that GetConn fires again.
// DNS for the fake host is served via the nettrace.LookupIPAltResolverKey
// context value.
func TestTransportEventTrace(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportEventTrace(t, mode, false)
	}, testNotParallel)
}

// TestTransportEventTrace_NoHooks sends the same POST with an empty
// ClientTrace and only checks that it succeeds.
func TestTransportEventTrace_NoHooks(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportEventTrace(t, mode, true)
	}, testNotParallel)
}

func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Signaled by the WroteRequest hook. In hooks mode the POST handler waits
	// on it before replying, so WroteRequest is logged before any response.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The follow-up GET gets an empty response.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	// NOTE(review): presumably needed so the Expect: 100-continue request
	// waits for the 100 response, producing Wait100Continue/Got100Continue.
	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one line to buf; mu guards buf because hooks may run on
	// other goroutines.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Fake DNS: resolve only "dns-is-faked.golang", to the test server's IP.
	// Any other lookup is a test error.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Clear every hook, including the TLS ones set above.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// No hooks were installed, so the log holds only the logf lines
		// written directly above; there is nothing further to check.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// Each header field written is expected to be reported exactly
		// once; these checks do not constrain their order.
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Second request with the same trace: GetConn should now appear twice
	// in the accumulated log.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
4803
4804 func TestTransportEventTraceTLSVerify(t *testing.T) {
4805 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
4806 }
4807 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
4808 var mu sync.Mutex
4809 var buf strings.Builder
4810 logf := func(format string, args ...any) {
4811 mu.Lock()
4812 defer mu.Unlock()
4813 fmt.Fprintf(&buf, format, args...)
4814 buf.WriteByte('\n')
4815 }
4816
4817 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4818 t.Error("Unexpected request")
4819 }), func(ts *httptest.Server) {
4820 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
4821 logf("%s", p)
4822 return len(p), nil
4823 }), "", 0)
4824 }).ts
4825
4826 certpool := x509.NewCertPool()
4827 certpool.AddCert(ts.Certificate())
4828
4829 c := &Client{Transport: &Transport{
4830 TLSClientConfig: &tls.Config{
4831 ServerName: "dns-is-faked.golang",
4832 RootCAs: certpool,
4833 },
4834 }}
4835
4836 trace := &httptrace.ClientTrace{
4837 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
4838 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
4839 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
4840 },
4841 }
4842
4843 req, _ := NewRequest("GET", ts.URL, nil)
4844 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
4845 _, err := c.Do(req)
4846 if err == nil {
4847 t.Error("Expected request to fail TLS verification")
4848 }
4849
4850 mu.Lock()
4851 got := buf.String()
4852 mu.Unlock()
4853
4854 wantOnce := func(sub string) {
4855 if strings.Count(got, sub) != 1 {
4856 t.Errorf("expected substring %q exactly once in output.", sub)
4857 }
4858 }
4859
4860 wantOnce("TLSHandshakeStart")
4861 wantOnce("TLSHandshakeDone")
4862 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
4863
4864 if t.Failed() {
4865 t.Errorf("Output:\n%s", got)
4866 }
4867 }
4868
// isDNSHijackedOnce guards the single DNS probe whose result is cached in
// isDNSHijacked.
var (
	isDNSHijackedOnce sync.Once
	isDNSHijacked     bool
)

// skipIfDNSHijacked skips t if the system resolver returns any address for
// a name that should not resolve. A resolver that answers for such a name
// presumably rewrites NXDOMAIN responses, which breaks tests that depend on
// lookups failing. The probe runs at most once per test binary.
func skipIfDNSHijacked(t *testing.T) {
	isDNSHijackedOnce.Do(func() {
		resolved, _ := net.LookupHost("dns-should-not-resolve.golang")
		isDNSHijacked = len(resolved) > 0
	})
	if !isDNSHijacked {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
4886
4887 func TestTransportEventTraceRealDNS(t *testing.T) {
4888 skipIfDNSHijacked(t)
4889 defer afterTest(t)
4890 tr := &Transport{}
4891 defer tr.CloseIdleConnections()
4892 c := &Client{Transport: tr}
4893
4894 var mu sync.Mutex
4895 var buf strings.Builder
4896 logf := func(format string, args ...any) {
4897 mu.Lock()
4898 defer mu.Unlock()
4899 fmt.Fprintf(&buf, format, args...)
4900 buf.WriteByte('\n')
4901 }
4902
4903 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
4904 trace := &httptrace.ClientTrace{
4905 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
4906 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
4907 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
4908 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
4909 }
4910 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
4911
4912 resp, err := c.Do(req)
4913 if err == nil {
4914 resp.Body.Close()
4915 t.Fatal("expected error during DNS lookup")
4916 }
4917
4918 mu.Lock()
4919 got := buf.String()
4920 mu.Unlock()
4921
4922 wantSub := func(sub string) {
4923 if !strings.Contains(got, sub) {
4924 t.Errorf("expected substring %q in output.", sub)
4925 }
4926 }
4927 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
4928 wantSub("DNSDone: {Addrs:[] Err:")
4929 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
4930 t.Errorf("should not see Connect events")
4931 }
4932 if t.Failed() {
4933 t.Errorf("Output:\n%s", got)
4934 }
4935 }
4936
4937
// TestTransportRejectsAlphaPort checks that a URL whose port contains letters
// makes Get fail with a *url.Error whose inner error names the bad port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}

	const want = `invalid port ":123foo" after host`
	switch ue, ok := err.(*url.Error); {
	case !ok:
		t.Fatalf("got %#v; want *url.Error", err)
	case ue.Err.Error() != want:
		t.Errorf("got error %q; want %q", ue.Err.Error(), want)
	}
}
4954
4955
4956
4957 func TestTLSHandshakeTrace(t *testing.T) {
4958 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
4959 }
4960 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
4961 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
4962
4963 var mu sync.Mutex
4964 var start, done bool
4965 trace := &httptrace.ClientTrace{
4966 TLSHandshakeStart: func() {
4967 mu.Lock()
4968 defer mu.Unlock()
4969 start = true
4970 },
4971 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
4972 mu.Lock()
4973 defer mu.Unlock()
4974 done = true
4975 if err != nil {
4976 t.Fatal("Expected error to be nil but was:", err)
4977 }
4978 },
4979 }
4980
4981 c := ts.Client()
4982 req, err := NewRequest("GET", ts.URL, nil)
4983 if err != nil {
4984 t.Fatal("Unable to construct test request:", err)
4985 }
4986 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
4987
4988 r, err := c.Do(req)
4989 if err != nil {
4990 t.Fatal("Unexpected error making request:", err)
4991 }
4992 r.Body.Close()
4993 mu.Lock()
4994 defer mu.Unlock()
4995 if !start {
4996 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
4997 }
4998 if !done {
4999 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5000 }
5001 }
5002
5003 func TestTransportMaxIdleConns(t *testing.T) {
5004 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5005 }
5006 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5007 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5008
5009 })).ts
5010 c := ts.Client()
5011 tr := c.Transport.(*Transport)
5012 tr.MaxIdleConns = 4
5013
5014 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5015 if err != nil {
5016 t.Fatal(err)
5017 }
5018 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5019 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5020 })
5021
5022 hitHost := func(n int) {
5023 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5024 req = req.WithContext(ctx)
5025 res, err := c.Do(req)
5026 if err != nil {
5027 t.Fatal(err)
5028 }
5029 res.Body.Close()
5030 }
5031 for i := 0; i < 4; i++ {
5032 hitHost(i)
5033 }
5034 want := []string{
5035 "|http|host-0.dns-is-faked.golang:" + port,
5036 "|http|host-1.dns-is-faked.golang:" + port,
5037 "|http|host-2.dns-is-faked.golang:" + port,
5038 "|http|host-3.dns-is-faked.golang:" + port,
5039 }
5040 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5041 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5042 }
5043
5044
5045 hitHost(4)
5046 want = []string{
5047 "|http|host-1.dns-is-faked.golang:" + port,
5048 "|http|host-2.dns-is-faked.golang:" + port,
5049 "|http|host-3.dns-is-faked.golang:" + port,
5050 "|http|host-4.dns-is-faked.golang:" + port,
5051 }
5052 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5053 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5054 }
5055 }
5056
5057 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
5058 func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
5059 if testing.Short() {
5060 t.Skip("skipping in short mode")
5061 }
5062
5063 timeout := 1 * time.Millisecond
5064 timeoutLoop:
5065 for {
5066 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5067
5068 }))
5069 tr := cst.tr
5070 tr.IdleConnTimeout = timeout
5071 defer tr.CloseIdleConnections()
5072 c := &Client{Transport: tr}
5073
5074 idleConns := func() []string {
5075 if mode == http2Mode {
5076 return tr.IdleConnStrsForTesting_h2()
5077 } else {
5078 return tr.IdleConnStrsForTesting()
5079 }
5080 }
5081
5082 var conn string
5083 doReq := func(n int) (timeoutOk bool) {
5084 req, _ := NewRequest("GET", cst.ts.URL, nil)
5085 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5086 PutIdleConn: func(err error) {
5087 if err != nil {
5088 t.Errorf("failed to keep idle conn: %v", err)
5089 }
5090 },
5091 }))
5092 res, err := c.Do(req)
5093 if err != nil {
5094 if strings.Contains(err.Error(), "use of closed network connection") {
5095 t.Logf("req %v: connection closed prematurely", n)
5096 return false
5097 }
5098 }
5099 res.Body.Close()
5100 conns := idleConns()
5101 if len(conns) != 1 {
5102 if len(conns) == 0 {
5103 t.Logf("req %v: no idle conns", n)
5104 return false
5105 }
5106 t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
5107 }
5108 if conn == "" {
5109 conn = conns[0]
5110 }
5111 if conn != conns[0] {
5112 t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
5113 return false
5114 }
5115 return true
5116 }
5117 for i := 0; i < 3; i++ {
5118 if !doReq(i) {
5119 t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
5120 timeout *= 2
5121 cst.close()
5122 continue timeoutLoop
5123 }
5124 time.Sleep(timeout / 2)
5125 }
5126
5127 waitCondition(t, timeout/2, func(d time.Duration) bool {
5128 if got := idleConns(); len(got) != 0 {
5129 if d >= timeout*3/2 {
5130 t.Logf("after %v, idle conns = %q", d, got)
5131 }
5132 return false
5133 }
5134 return true
5135 })
5136 break
5137 }
5138 }
5139
5140
5141
5142
5143
5144
5145
5146
5147
5148
5149
5150
// TestIdleConnH2Crash runs testIdleConnH2Crash in HTTP/2 mode only.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }

// testIdleConnH2Crash arranges for an HTTP/2 connection to be handed to the
// Transport only after the request that triggered the dial has already been
// canceled and Do has returned an error. It then sleeps for several
// IdleConnTimeout periods. The test has no assertions after that point; it
// presumably exists to exercise the Transport's handling (idle pooling and
// idle-timeout expiry) of such an orphaned conn and to catch a crash there.
// NOTE(review): intent inferred from the test name; confirm against history.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// No request is expected to reach the handler; Do fails first.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr: test -> dialer, "Do has returned its error".
	// testDone: closed on return so a still-blocked dialer can exit.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	// A short idle timeout so the final sleep covers several expiries.
	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// The handshake succeeded; cancel the request now, while the
		// dial is still in progress from the Transport's point of view.
		cancel()

		// Hold the conn back until Do has observed the cancellation (or
		// the test is ending), then return it to the Transport anyway.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Give the Transport time to process the late conn and let the idle
	// timeout fire several times.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5201
// funcConn is a net.Conn whose Read and Write are delegated to the supplied
// functions and whose Close does nothing. Every other net.Conn method is
// promoted from the embedded Conn (left nil in TestTransportReturnsPeekError,
// so calling those would panic there).
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards p to c.read and returns its results unchanged.
func (c funcConn) Read(p []byte) (int, error) {
	return c.read(p)
}

// Write forwards p to c.write and returns its results unchanged.
func (c funcConn) Write(p []byte) (int, error) {
	return c.write(p)
}

// Close is a no-op that always succeeds.
func (c funcConn) Close() error {
	return nil
}
5211
5212
5213
5214 func TestTransportReturnsPeekError(t *testing.T) {
5215 errValue := errors.New("specific error value")
5216
5217 wrote := make(chan struct{})
5218 var wroteOnce sync.Once
5219
5220 tr := &Transport{
5221 Dial: func(network, addr string) (net.Conn, error) {
5222 c := funcConn{
5223 read: func([]byte) (int, error) {
5224 <-wrote
5225 return 0, errValue
5226 },
5227 write: func(p []byte) (int, error) {
5228 wroteOnce.Do(func() { close(wrote) })
5229 return len(p), nil
5230 },
5231 }
5232 return c, nil
5233 },
5234 }
5235 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5236 if err != errValue {
5237 t.Errorf("error = %#v; want %v", err, errValue)
5238 }
5239 }
5240
5241
5242 func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
5243 func testTransportIDNA(t *testing.T, mode testMode) {
5244 const uniDomain = "гофер.го"
5245 const punyDomain = "xn--c1ae0ajs.xn--c1aw"
5246
5247 var port string
5248 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5249 want := punyDomain + ":" + port
5250 if r.Host != want {
5251 t.Errorf("Host header = %q; want %q", r.Host, want)
5252 }
5253 if mode == http2Mode {
5254 if r.TLS == nil {
5255 t.Errorf("r.TLS == nil")
5256 } else if r.TLS.ServerName != punyDomain {
5257 t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
5258 }
5259 }
5260 w.Header().Set("Hit-Handler", "1")
5261 }), func(tr *Transport) {
5262 if tr.TLSClientConfig != nil {
5263 tr.TLSClientConfig.InsecureSkipVerify = true
5264 }
5265 })
5266
5267 ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
5268 if err != nil {
5269 t.Fatal(err)
5270 }
5271
5272
5273 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
5274 if host != punyDomain {
5275 t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
5276 return nil, nil
5277 }
5278 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5279 })
5280
5281 req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
5282 trace := &httptrace.ClientTrace{
5283 GetConn: func(hostPort string) {
5284 want := net.JoinHostPort(punyDomain, port)
5285 if hostPort != want {
5286 t.Errorf("getting conn for %q; want %q", hostPort, want)
5287 }
5288 },
5289 DNSStart: func(e httptrace.DNSStartInfo) {
5290 if e.Host != punyDomain {
5291 t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
5292 }
5293 },
5294 }
5295 req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
5296
5297 res, err := cst.tr.RoundTrip(req)
5298 if err != nil {
5299 t.Fatal(err)
5300 }
5301 defer res.Body.Close()
5302 if res.Header.Get("Hit-Handler") != "1" {
5303 out, err := httputil.DumpResponse(res, true)
5304 if err != nil {
5305 t.Fatal(err)
5306 }
5307 t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
5308 }
5309 }
5310
5311
5312 func TestTransportProxyConnectHeader(t *testing.T) {
5313 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5314 }
5315 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5316 reqc := make(chan *Request, 1)
5317 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5318 if r.Method != "CONNECT" {
5319 t.Errorf("method = %q; want CONNECT", r.Method)
5320 }
5321 reqc <- r
5322 c, _, err := w.(Hijacker).Hijack()
5323 if err != nil {
5324 t.Errorf("Hijack: %v", err)
5325 return
5326 }
5327 c.Close()
5328 })).ts
5329
5330 c := ts.Client()
5331 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5332 return url.Parse(ts.URL)
5333 }
5334 c.Transport.(*Transport).ProxyConnectHeader = Header{
5335 "User-Agent": {"foo"},
5336 "Other": {"bar"},
5337 }
5338
5339 res, err := c.Get("https://dummy.tld/")
5340 if err == nil {
5341 res.Body.Close()
5342 t.Errorf("unexpected success")
5343 }
5344
5345 r := <-reqc
5346 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5347 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5348 }
5349 if got, want := r.Header.Get("Other"), "bar"; got != want {
5350 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5351 }
5352 }
5353
5354 func TestTransportProxyGetConnectHeader(t *testing.T) {
5355 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5356 }
5357 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5358 reqc := make(chan *Request, 1)
5359 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5360 if r.Method != "CONNECT" {
5361 t.Errorf("method = %q; want CONNECT", r.Method)
5362 }
5363 reqc <- r
5364 c, _, err := w.(Hijacker).Hijack()
5365 if err != nil {
5366 t.Errorf("Hijack: %v", err)
5367 return
5368 }
5369 c.Close()
5370 })).ts
5371
5372 c := ts.Client()
5373 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5374 return url.Parse(ts.URL)
5375 }
5376
5377 c.Transport.(*Transport).ProxyConnectHeader = Header{
5378 "User-Agent": {"foo"},
5379 "Other": {"bar"},
5380 }
5381 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5382 return Header{
5383 "User-Agent": {"foo2"},
5384 "Other": {"bar2"},
5385 }, nil
5386 }
5387
5388 res, err := c.Get("https://dummy.tld/")
5389 if err == nil {
5390 res.Body.Close()
5391 t.Errorf("unexpected success")
5392 }
5393
5394 r := <-reqc
5395 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5396 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5397 }
5398 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5399 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5400 }
5401 }
5402
// errFakeRoundTrip is the error returned by every funcRoundTripper.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the wrapped function for its
// side effects and then always fails, never producing a Response.
type funcRoundTripper func()

// RoundTrip ignores the request, invokes fn, and returns errFakeRoundTrip.
func (fn funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	var noResponse *Response
	fn()
	return noResponse, errFakeRoundTrip
}
5411
// wantBody checks a (res, err) pair as returned by Client.Do or
// Transport.RoundTrip. It returns err unchanged if it is non-nil. Otherwise
// it reads res.Body in full and closes it, then returns, in order of
// precedence: a read error, a mismatch between the body and want, or a Close
// error. It returns nil if the body equals want and closes cleanly.
//
// The body is closed on every path after the initial err check, so a
// failing check does not leave the response body open.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, readErr := io.ReadAll(res.Body)
	closeErr := res.Body.Close()
	if readErr != nil {
		return fmt.Errorf("error reading body: %w", readErr)
	}
	if string(slurp) != want {
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if closeErr != nil {
		return fmt.Errorf("body Close = %w", closeErr)
	}
	return nil
}
5428
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 first and falls back to IPv6. If both fail, it fails the test
// with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	candidates := []struct{ network, addr string }{
		{"tcp", "127.0.0.1:0"},
		{"tcp6", "[::1]:0"},
	}
	var err error
	for _, c := range candidates {
		var ln net.Listener
		if ln, err = net.Listen(c.network, c.addr); err == nil {
			return ln
		}
	}
	t.Fatal(err)
	return nil
}
5439
// countCloseReader pairs an io.Reader with a Close method that counts how
// many times it was called, by incrementing the int that n points to.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close increments *cr.n and always succeeds.
func (cr countCloseReader) Close() error {
	*cr.n++
	return nil
}
5449
5450
// rgz is a fixed gzip stream. Its header bytes are the gzip magic (0x1f,
// 0x8b), compression method 8 (deflate), and flags 0x08 (FNAME). The
// NUL-terminated embedded file name that follows is "recursive".
// NOTE(review): presumably a gzip "quine" whose decompressed output equals
// the compressed bytes, used to test recursive/nested gzip decoding; confirm
// against the test that consumes it.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5485
5486
5487
5488 func TestMissingStatusNoPanic(t *testing.T) {
5489 t.Parallel()
5490
5491 const want = "unknown status code"
5492
5493 ln := newLocalListener(t)
5494 addr := ln.Addr().String()
5495 done := make(chan bool)
5496 fullAddrURL := fmt.Sprintf("http://%s", addr)
5497 raw := "HTTP/1.1 400\r\n" +
5498 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5499 "Content-Type: text/html; charset=utf-8\r\n" +
5500 "Content-Length: 10\r\n" +
5501 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5502 "Vary: Accept-Encoding\r\n\r\n" +
5503 "Aloha Olaa"
5504
5505 go func() {
5506 defer close(done)
5507
5508 conn, _ := ln.Accept()
5509 if conn != nil {
5510 io.WriteString(conn, raw)
5511 io.ReadAll(conn)
5512 conn.Close()
5513 }
5514 }()
5515
5516 proxyURL, err := url.Parse(fullAddrURL)
5517 if err != nil {
5518 t.Fatalf("proxyURL: %v", err)
5519 }
5520
5521 tr := &Transport{Proxy: ProxyURL(proxyURL)}
5522
5523 req, _ := NewRequest("GET", "https://golang.org/", nil)
5524 res, err, panicked := doFetchCheckPanic(tr, req)
5525 if panicked {
5526 t.Error("panicked, expecting an error")
5527 }
5528 if res != nil && res.Body != nil {
5529 io.Copy(io.Discard, res.Body)
5530 res.Body.Close()
5531 }
5532
5533 if err == nil || !strings.Contains(err.Error(), want) {
5534 t.Errorf("got=%v want=%q", err, want)
5535 }
5536
5537 ln.Close()
5538 <-done
5539 }
5540
// doFetchCheckPanic calls tr.RoundTrip(req) and reports whether it panicked.
// A panic is recovered and swallowed; in that case res and err are left as
// their zero values.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5550
5551
5552
5553 func TestNoBodyOnChunked304Response(t *testing.T) {
5554 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
5555 }
5556 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
5557 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5558 conn, buf, _ := w.(Hijacker).Hijack()
5559 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
5560 buf.Flush()
5561 conn.Close()
5562 }))
5563
5564
5565
5566
5567
5568 cst.tr.DisableKeepAlives = true
5569
5570 res, err := cst.c.Get(cst.ts.URL)
5571 if err != nil {
5572 t.Fatal(err)
5573 }
5574
5575 if res.Body != NoBody {
5576 t.Errorf("Unexpected body on 304 response")
5577 }
5578 }
5579
// funcWriter adapts an ordinary function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write hands p to f and returns whatever f returns.
func (f funcWriter) Write(p []byte) (int, error) {
	n, err := f(p)
	return n, err
}
5583
// doneContext wraps a Context so that Done always reports cancellation and
// Err returns the configured error. Deadline and Value are promoted from the
// embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns an already-closed channel; a new one is made on every call.
func (doneContext) Done() <-chan struct{} {
	closed := make(chan struct{})
	close(closed)
	return closed
}

// Err returns d.err, ignoring the state of the embedded Context.
func (d doneContext) Err() error {
	return d.err
}
5596
5597
5598 func TestTransportCheckContextDoneEarly(t *testing.T) {
5599 tr := &Transport{}
5600 req, _ := NewRequest("GET", "http://fake.example/", nil)
5601 wantErr := errors.New("some error")
5602 req = req.WithContext(doneContext{context.Background(), wantErr})
5603 _, err := tr.RoundTrip(req)
5604 if err != wantErr {
5605 t.Errorf("error = %v; want %v", err, wantErr)
5606 }
5607 }
5608
5609
5610
5611
5612
5613
// TestClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// expires before any response headers arrive, the client closes the
// underlying connection. The server-side handler then reads io.EOF.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}

// testClientTimeoutKillsConn_BeforeHeaders starts with a 1ms client timeout.
// If the handler is not seen within 10x the timeout, it rebuilds the
// client/server pair and retries with the timeout doubled.
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	timeout := 1 * time.Millisecond
	for {
		// inHandler:     handler -> test, "request context is done; checking conn".
		// cancelHandler: test -> handler, "this attempt is abandoned".
		// handlerDone:   handler -> test, "finished checking the conn".
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Never respond; wait until the server cancels the request
			// context (presumably because the client dropped the conn).
			<-r.Context().Done()

			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// Take over the raw conn; a clean EOF shows the client
			// closed it.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler never observed the cancellation within 10x the
			// timeout; assume the machine is slow, tear down this pair,
			// and retry with a longer timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
5672
5673
5674
5675
5676
5677
// TestClientTimeoutKillsConn_AfterHeaders checks that canceling a request
// after the headers have arrived, while the body is still being read, makes
// the client close the underlying connection and fail the body read.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	// inHandler:     handler -> test, "headers flushed; about to hijack".
	// cancelHandler: test -> handler, "request failed; give up".
	// handlerDone:   handler -> test, "finished checking the conn".
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise 100 body bytes and send the headers now.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Send only 3 of the promised 100 bytes so the body stays incomplete.
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// The client is expected to tear the conn down after cancellation,
		// so this Read must fail without data. Only a non-nil error is
		// required (not specifically io.EOF), presumably because the
		// close may surface as EOF or as a reset depending on timing.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// A Client.Timeout is configured but far too long to fire during the
	// test. Cancellation is triggered deterministically through req.Cancel
	// below. NOTE(review): presumably the Timeout is set so that the
	// Client's timeout-handling path is active; confirm.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Headers are in; cancel while the body is still short of its declared
	// length, then reading the rest must fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Wait for the handler to finish checking the conn.
	<-inHandler
	<-handlerDone
}
5743
5744 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
5745 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
5746 }
5747 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
5748 done := make(chan struct{})
5749 defer close(done)
5750 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5751 conn, _, err := w.(Hijacker).Hijack()
5752 if err != nil {
5753 t.Error(err)
5754 return
5755 }
5756 defer conn.Close()
5757 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
5758 bs := bufio.NewScanner(conn)
5759 bs.Scan()
5760 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
5761 <-done
5762 }))
5763
5764 req, _ := NewRequest("GET", cst.ts.URL, nil)
5765 req.Header.Set("Upgrade", "foo")
5766 req.Header.Set("Connection", "upgrade")
5767 res, err := cst.c.Do(req)
5768 if err != nil {
5769 t.Fatal(err)
5770 }
5771 if res.StatusCode != 101 {
5772 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
5773 }
5774 rwc, ok := res.Body.(io.ReadWriteCloser)
5775 if !ok {
5776 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
5777 }
5778 defer rwc.Close()
5779 bs := bufio.NewScanner(rwc)
5780 if !bs.Scan() {
5781 t.Fatalf("expected readable input")
5782 }
5783 if got, want := bs.Text(), "Some buffered data"; got != want {
5784 t.Errorf("read %q; want %q", got, want)
5785 }
5786 io.WriteString(rwc, "echo\n")
5787 if !bs.Scan() {
5788 t.Fatalf("expected another line")
5789 }
5790 if got, want := bs.Text(), "ECHO"; got != want {
5791 t.Errorf("read %q; want %q", got, want)
5792 }
5793 }
5794
5795 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
5796 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
5797 const target = "backend:443"
5798 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5799 if r.Method != "CONNECT" {
5800 t.Errorf("unexpected method %q", r.Method)
5801 w.WriteHeader(500)
5802 return
5803 }
5804 if r.RequestURI != target {
5805 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
5806 w.WriteHeader(500)
5807 return
5808 }
5809 nc, brw, err := w.(Hijacker).Hijack()
5810 if err != nil {
5811 t.Error(err)
5812 return
5813 }
5814 defer nc.Close()
5815 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
5816
5817 for {
5818 line, err := brw.ReadString('\n')
5819 if err != nil {
5820 if err != io.EOF {
5821 t.Error(err)
5822 }
5823 return
5824 }
5825 io.WriteString(brw, strings.ToUpper(line))
5826 brw.Flush()
5827 }
5828 }))
5829 pr, pw := io.Pipe()
5830 defer pw.Close()
5831 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
5832 if err != nil {
5833 t.Fatal(err)
5834 }
5835 req.URL.Opaque = target
5836 res, err := cst.c.Do(req)
5837 if err != nil {
5838 t.Fatal(err)
5839 }
5840 defer res.Body.Close()
5841 if res.StatusCode != 200 {
5842 t.Fatalf("status code = %d; want 200", res.StatusCode)
5843 }
5844 br := bufio.NewReader(res.Body)
5845 for _, str := range []string{"foo", "bar", "baz"} {
5846 fmt.Fprintf(pw, "%s\n", str)
5847 got, err := br.ReadString('\n')
5848 if err != nil {
5849 t.Fatal(err)
5850 }
5851 got = strings.TrimSpace(got)
5852 want := strings.ToUpper(str)
5853 if got != want {
5854 t.Fatalf("got %q; want %q", got, want)
5855 }
5856 }
5857 }
5858
5859 func TestTransportRequestReplayable(t *testing.T) {
5860 someBody := io.NopCloser(strings.NewReader(""))
5861 tests := []struct {
5862 name string
5863 req *Request
5864 want bool
5865 }{
5866 {
5867 name: "GET",
5868 req: &Request{Method: "GET"},
5869 want: true,
5870 },
5871 {
5872 name: "GET_http.NoBody",
5873 req: &Request{Method: "GET", Body: NoBody},
5874 want: true,
5875 },
5876 {
5877 name: "GET_body",
5878 req: &Request{Method: "GET", Body: someBody},
5879 want: false,
5880 },
5881 {
5882 name: "POST",
5883 req: &Request{Method: "POST"},
5884 want: false,
5885 },
5886 {
5887 name: "POST_idempotency-key",
5888 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
5889 want: true,
5890 },
5891 {
5892 name: "POST_x-idempotency-key",
5893 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
5894 want: true,
5895 },
5896 {
5897 name: "POST_body",
5898 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
5899 want: false,
5900 },
5901 }
5902 for _, tt := range tests {
5903 t.Run(tt.name, func(t *testing.T) {
5904 got := tt.req.ExportIsReplayable()
5905 if got != tt.want {
5906 t.Errorf("replyable = %v; want %v", got, tt.want)
5907 }
5908 })
5909 }
5910 }
5911
5912
5913
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom
// method was ever invoked. Tests use it to detect whether a request body
// was sent through the conn's io.ReaderFrom path.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool // becomes true on the first ReadFrom call and stays true
}

// ReadFrom records the call and then delegates to the wrapped TCP conn.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (int64, error) {
	m.ReadFromCalled = true
	n, err := m.TCPConn.ReadFrom(src)
	return n, err
}
5924
5925 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
5926 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
5927 nBytes := int64(1 << 10)
5928 newFileFunc := func() (r io.Reader, done func(), err error) {
5929 f, err := os.CreateTemp("", "net-http-newfilefunc")
5930 if err != nil {
5931 return nil, nil, err
5932 }
5933
5934
5935 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
5936 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
5937 }
5938 if _, err := f.Seek(0, 0); err != nil {
5939 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
5940 }
5941
5942 done = func() {
5943 f.Close()
5944 os.Remove(f.Name())
5945 }
5946
5947 return f, done, nil
5948 }
5949
5950 newBufferFunc := func() (io.Reader, func(), error) {
5951 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
5952 }
5953
5954 cases := []struct {
5955 name string
5956 readerFunc func() (io.Reader, func(), error)
5957 contentLength int64
5958 expectedReadFrom bool
5959 }{
5960 {
5961 name: "file, length",
5962 readerFunc: newFileFunc,
5963 contentLength: nBytes,
5964 expectedReadFrom: true,
5965 },
5966 {
5967 name: "file, no length",
5968 readerFunc: newFileFunc,
5969 },
5970 {
5971 name: "file, negative length",
5972 readerFunc: newFileFunc,
5973 contentLength: -1,
5974 },
5975 {
5976 name: "buffer",
5977 contentLength: nBytes,
5978 readerFunc: newBufferFunc,
5979 },
5980 {
5981 name: "buffer, no length",
5982 readerFunc: newBufferFunc,
5983 },
5984 {
5985 name: "buffer, length -1",
5986 contentLength: -1,
5987 readerFunc: newBufferFunc,
5988 },
5989 }
5990
5991 for _, tc := range cases {
5992 t.Run(tc.name, func(t *testing.T) {
5993 r, cleanup, err := tc.readerFunc()
5994 if err != nil {
5995 t.Fatal(err)
5996 }
5997 defer cleanup()
5998
5999 tConn := &testMockTCPConn{}
6000 trFunc := func(tr *Transport) {
6001 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6002 var d net.Dialer
6003 conn, err := d.DialContext(ctx, network, addr)
6004 if err != nil {
6005 return nil, err
6006 }
6007
6008 tcpConn, ok := conn.(*net.TCPConn)
6009 if !ok {
6010 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6011 }
6012
6013 tConn.TCPConn = tcpConn
6014 return tConn, nil
6015 }
6016 }
6017
6018 cst := newClientServerTest(
6019 t,
6020 mode,
6021 HandlerFunc(func(w ResponseWriter, r *Request) {
6022 io.Copy(io.Discard, r.Body)
6023 r.Body.Close()
6024 w.WriteHeader(200)
6025 }),
6026 trFunc,
6027 )
6028
6029 req, err := NewRequest("PUT", cst.ts.URL, r)
6030 if err != nil {
6031 t.Fatal(err)
6032 }
6033 req.ContentLength = tc.contentLength
6034 req.Header.Set("Content-Type", "application/octet-stream")
6035 resp, err := cst.c.Do(req)
6036 if err != nil {
6037 t.Fatal(err)
6038 }
6039 defer resp.Body.Close()
6040 if resp.StatusCode != 200 {
6041 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6042 }
6043
6044 expectedReadFrom := tc.expectedReadFrom
6045 if mode != http1Mode {
6046 expectedReadFrom = false
6047 }
6048 if !tConn.ReadFromCalled && expectedReadFrom {
6049 t.Fatalf("did not call ReadFrom")
6050 }
6051
6052 if tConn.ReadFromCalled && !expectedReadFrom {
6053 t.Fatalf("ReadFrom was unexpectedly invoked")
6054 }
6055 })
6056 }
6057 }
6058
// TestTransportClone verifies that Transport.Clone copies every exported
// field. It clones a Transport whose exported fields are all non-zero and
// reports any exported field that is zero in the clone. It also checks that
// TLSNextProto entries are carried over, and that cloning a zero Transport
// leaves TLSNextProto nil.
func TestTransportClone(t *testing.T) {
	src := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}

	clone := src.Clone()
	val := reflect.ValueOf(clone).Elem()
	typ := val.Type()
	for idx, total := 0, typ.NumField(); idx < total; idx++ {
		field := typ.Field(idx)
		// Only exported fields are part of the cloning contract.
		if token.IsExported(field.Name) && val.Field(idx).IsZero() {
			t.Errorf("cloned field t2.%s is zero", field.Name)
		}
	}

	if _, found := clone.TLSNextProto["foo"]; !found {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// A zero Transport's clone must still have a nil TLSNextProto.
	if emptyClone := new(Transport).Clone(); emptyClone.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6113
6114 func TestIs408(t *testing.T) {
6115 tests := []struct {
6116 in string
6117 want bool
6118 }{
6119 {"HTTP/1.0 408", true},
6120 {"HTTP/1.1 408", true},
6121 {"HTTP/1.8 408", true},
6122 {"HTTP/2.0 408", false},
6123 {"HTTP/1.1 408 ", true},
6124 {"HTTP/1.1 40", false},
6125 {"http/1.0 408", false},
6126 {"HTTP/1-1 408", false},
6127 }
6128 for _, tt := range tests {
6129 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6130 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6131 }
6132 }
6133 }
6134
6135 func TestTransportIgnores408(t *testing.T) {
6136 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6137 }
6138 func testTransportIgnores408(t *testing.T, mode testMode) {
6139
6140 defer log.SetOutput(log.Writer())
6141
6142 var logout strings.Builder
6143 log.SetOutput(&logout)
6144
6145 const target = "backend:443"
6146
6147 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6148 nc, _, err := w.(Hijacker).Hijack()
6149 if err != nil {
6150 t.Error(err)
6151 return
6152 }
6153 defer nc.Close()
6154 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6155 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6156 }))
6157 req, err := NewRequest("GET", cst.ts.URL, nil)
6158 if err != nil {
6159 t.Fatal(err)
6160 }
6161 res, err := cst.c.Do(req)
6162 if err != nil {
6163 t.Fatal(err)
6164 }
6165 slurp, err := io.ReadAll(res.Body)
6166 if err != nil {
6167 t.Fatal(err)
6168 }
6169 if err != nil {
6170 t.Fatal(err)
6171 }
6172 if string(slurp) != "ok" {
6173 t.Fatalf("got %q; want ok", slurp)
6174 }
6175
6176 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6177 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6178 if d > 0 {
6179 t.Logf("%v idle conns still present after %v", n, d)
6180 }
6181 return false
6182 }
6183 return true
6184 })
6185 if got := logout.String(); got != "" {
6186 t.Fatalf("expected no log output; got: %s", got)
6187 }
6188 }
6189
6190 func TestInvalidHeaderResponse(t *testing.T) {
6191 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6192 }
6193 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6194 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6195 conn, buf, _ := w.(Hijacker).Hijack()
6196 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6197 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6198 "Content-Type: text/html; charset=utf-8\r\n" +
6199 "Content-Length: 0\r\n" +
6200 "Foo : bar\r\n\r\n"))
6201 buf.Flush()
6202 conn.Close()
6203 }))
6204 res, err := cst.c.Get(cst.ts.URL)
6205 if err != nil {
6206 t.Fatal(err)
6207 }
6208 defer res.Body.Close()
6209 if v := res.Header.Get("Foo"); v != "" {
6210 t.Errorf(`unexpected "Foo" header: %q`, v)
6211 }
6212 if v := res.Header.Get("Foo "); v != "bar" {
6213 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6214 }
6215 }
6216
// bodyCloser is an always-empty io.ReadCloser whose boolean value records
// whether Close has been called.
type bodyCloser bool

// Close marks the body as closed and never fails.
func (b *bodyCloser) Close() error {
	*b = bodyCloser(true)
	return nil
}

// Read reports end of stream immediately without touching the buffer.
func (b *bodyCloser) Read([]byte) (n int, err error) {
	err = io.EOF
	return n, err
}
6226
6227
6228
6229 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6230 run(t, testTransportClosesBodyOnInvalidRequests)
6231 }
6232 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6233 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6234 t.Errorf("Should not have been invoked")
6235 })).ts
6236
6237 u, _ := url.Parse(cst.URL)
6238
6239 tests := []struct {
6240 name string
6241 req *Request
6242 wantErr string
6243 }{
6244 {
6245 name: "invalid method",
6246 req: &Request{
6247 Method: " ",
6248 URL: u,
6249 },
6250 wantErr: `invalid method " "`,
6251 },
6252 {
6253 name: "nil URL",
6254 req: &Request{
6255 Method: "GET",
6256 },
6257 wantErr: `nil Request.URL`,
6258 },
6259 {
6260 name: "invalid header key",
6261 req: &Request{
6262 Method: "GET",
6263 Header: Header{"💡": {"emoji"}},
6264 URL: u,
6265 },
6266 wantErr: `invalid header field name "💡"`,
6267 },
6268 {
6269 name: "invalid header value",
6270 req: &Request{
6271 Method: "POST",
6272 Header: Header{"key": {"\x19"}},
6273 URL: u,
6274 },
6275 wantErr: `invalid header field value for "key"`,
6276 },
6277 {
6278 name: "non HTTP(s) scheme",
6279 req: &Request{
6280 Method: "POST",
6281 URL: &url.URL{Scheme: "faux"},
6282 },
6283 wantErr: `unsupported protocol scheme "faux"`,
6284 },
6285 {
6286 name: "no Host in URL",
6287 req: &Request{
6288 Method: "POST",
6289 URL: &url.URL{Scheme: "http"},
6290 },
6291 wantErr: `no Host in request URL`,
6292 },
6293 }
6294
6295 for _, tt := range tests {
6296 t.Run(tt.name, func(t *testing.T) {
6297 var bc bodyCloser
6298 req := tt.req
6299 req.Body = &bc
6300 _, err := cst.Client().Do(tt.req)
6301 if err == nil {
6302 t.Fatal("Expected an error")
6303 }
6304 if !bc {
6305 t.Fatal("Expected body to have been closed")
6306 }
6307 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6308 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6309 }
6310 })
6311 }
6312 }
6313
6314
6315
// breakableConn wraps a net.Conn and rejects every Write while its
// (possibly shared) brokenState is marked broken.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded flag that controls whether the
// breakableConns pointing at it accept writes.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write forwards b to the wrapped conn unless the state is broken. When it
// is broken, nothing is written and an error is returned. The state's lock
// is held for the duration of the call.
func (bc *breakableConn) Write(b []byte) (n int, err error) {
	bc.Lock()
	defer bc.Unlock()
	if !bc.broken {
		return bc.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6334
6335
6336 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6337 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6338 }
6339 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6340 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6341
6342 var brokenState brokenState
6343
6344 const numReqs = 5
6345 var numDials, gotConns uint32
6346
6347 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6348 atomic.AddUint32(&numDials, 1)
6349 c, err := net.Dial(netw, addr)
6350 if err != nil {
6351 t.Errorf("unexpected Dial error: %v", err)
6352 return nil, err
6353 }
6354 return &breakableConn{c, &brokenState}, err
6355 }
6356
6357 for i := 1; i <= numReqs; i++ {
6358 brokenState.Lock()
6359 brokenState.broken = false
6360 brokenState.Unlock()
6361
6362
6363
6364
6365 doBreak := i != numReqs
6366
6367 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6368 GotConn: func(info httptrace.GotConnInfo) {
6369 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6370 atomic.AddUint32(&gotConns, 1)
6371 },
6372 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6373 brokenState.Lock()
6374 defer brokenState.Unlock()
6375 if doBreak {
6376 brokenState.broken = true
6377 }
6378 },
6379 })
6380 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6381 if err != nil {
6382 t.Fatal(err)
6383 }
6384 _, err = cst.c.Do(req)
6385 if doBreak != (err != nil) {
6386 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6387 }
6388 }
6389 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6390 t.Errorf("GotConn calls = %v; want %v", got, want)
6391 }
6392 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6393 t.Errorf("Dials = %v; want %v", got, want)
6394 }
6395 }
6396
6397
6398
6399
6400
6401 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6402 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6403 }
6404 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6405 CondSkipHTTP2(t)
6406
6407 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6408 _, err := w.Write([]byte("foo"))
6409 if err != nil {
6410 t.Fatalf("Write: %v", err)
6411 }
6412 })
6413
6414 ts := newClientServerTest(t, mode, h).ts
6415
6416 c := ts.Client()
6417 tr := c.Transport.(*Transport)
6418 tr.MaxConnsPerHost = 1
6419
6420 errCh := make(chan error, 300)
6421 doReq := func() {
6422 resp, err := c.Get(ts.URL)
6423 if err != nil {
6424 errCh <- fmt.Errorf("request failed: %v", err)
6425 return
6426 }
6427 defer resp.Body.Close()
6428 _, err = io.ReadAll(resp.Body)
6429 if err != nil {
6430 errCh <- fmt.Errorf("read body failed: %v", err)
6431 }
6432 }
6433
6434 var wg sync.WaitGroup
6435 for i := 0; i < 300; i++ {
6436 wg.Add(1)
6437 go func() {
6438 defer wg.Done()
6439 doReq()
6440 }()
6441 }
6442 wg.Wait()
6443 close(errCh)
6444
6445 for err := range errCh {
6446 t.Errorf("error occurred: %v", err)
6447 }
6448 }
6449
6450
6451
6452
6453 func TestAltProtoCancellation(t *testing.T) {
6454 defer afterTest(t)
6455 tr := &Transport{}
6456 c := &Client{
6457 Transport: tr,
6458 Timeout: time.Millisecond,
6459 }
6460 tr.RegisterProtocol("cancel", cancelProto{})
6461 _, err := c.Get("cancel://bar.com/path")
6462 if err == nil {
6463 t.Error("request unexpectedly succeeded")
6464 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6465 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6466 }
6467 }
6468
// errCancelProto is the error cancelProto.RoundTrip returns once its
// request's Cancel channel fires.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never produces a response.
type cancelProto struct{}

// RoundTrip waits until req.Cancel is closed (or delivers a value) and then
// fails with errCancelProto. If req.Cancel is nil it blocks forever.
func (cancelProto) RoundTrip(req *Request) (*Response, error) {
	cancel := req.Cancel
	<-cancel
	return nil, errCancelProto
}
6477
// roundTripFunc adapts a plain function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip invokes the underlying function with req.
func (f roundTripFunc) RoundTrip(req *Request) (*Response, error) {
	res, err := f(req)
	return res, err
}
6481
6482
6483 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6484 func testIssue32441(t *testing.T, mode testMode) {
6485 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6486 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6487 t.Error("body length is zero")
6488 }
6489 })).ts
6490 c := ts.Client()
6491 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6492
6493 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6494 t.Error("body length is zero during round trip")
6495 }
6496 return nil, ErrSkipAltProtocol
6497 }))
6498 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6499 t.Error(err)
6500 }
6501 }
6502
6503
6504
6505 func TestTransportRejectsSignInContentLength(t *testing.T) {
6506 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6507 }
6508 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6509 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6510 w.Header().Set("Content-Length", "+3")
6511 w.Write([]byte("abc"))
6512 })).ts
6513
6514 c := cst.Client()
6515 res, err := c.Get(cst.URL)
6516 if err == nil || res != nil {
6517 t.Fatal("Expected a non-nil error and a nil http.Response")
6518 }
6519 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6520 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6521 }
6522 }
6523
6524
// dumpConn is a minimal net.Conn. Reads come from the embedded Reader and
// writes go to the embedded Writer. Close, the address getters, and the
// deadline setters are all no-ops returning nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error                     { return nil }
func (*dumpConn) LocalAddr() net.Addr              { return nil }
func (*dumpConn) RemoteAddr() net.Addr             { return nil }
func (*dumpConn) SetDeadline(time.Time) error      { return nil }
func (*dumpConn) SetReadDeadline(time.Time) error  { return nil }
func (*dumpConn) SetWriteDeadline(time.Time) error { return nil }
6536
6537
6538
// delegateReader is an io.Reader whose real source arrives later on
// channel c. The first Read blocks until a reader is received, and that
// reader then serves this and all later Reads. If c is closed first, Read
// returns a "delegate closed" error.
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r != nil {
		return d.r.Read(p)
	}
	next, ok := <-d.c
	if !ok {
		return 0, errors.New("delegate closed")
	}
	d.r = next
	return d.r.Read(p)
}
6553
6554 func testTransportRace(req *Request) {
6555 save := req.Body
6556 pr, pw := io.Pipe()
6557 defer pr.Close()
6558 defer pw.Close()
6559 dr := &delegateReader{c: make(chan io.Reader)}
6560
6561 t := &Transport{
6562 Dial: func(net, addr string) (net.Conn, error) {
6563 return &dumpConn{pw, dr}, nil
6564 },
6565 }
6566 defer t.CloseIdleConnections()
6567
6568 quitReadCh := make(chan struct{})
6569
6570 go func() {
6571 defer close(quitReadCh)
6572
6573 req, err := ReadRequest(bufio.NewReader(pr))
6574 if err == nil {
6575
6576
6577 io.Copy(io.Discard, req.Body)
6578 req.Body.Close()
6579 }
6580 select {
6581 case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
6582 case quitReadCh <- struct{}{}:
6583
6584 close(dr.c)
6585 }
6586 }()
6587
6588 t.RoundTrip(req)
6589
6590
6591
6592 pw.Close()
6593 <-quitReadCh
6594
6595 req.Body = save
6596 }
6597
6598
6599
6600
6601
6602 func TestErrorWriteLoopRace(t *testing.T) {
6603 if testing.Short() {
6604 return
6605 }
6606 t.Parallel()
6607 for i := 0; i < 1000; i++ {
6608 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
6609 ctx, cancel := context.WithTimeout(context.Background(), delay)
6610 defer cancel()
6611
6612 r := bytes.NewBuffer(make([]byte, 10000))
6613 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
6614 if err != nil {
6615 t.Fatal(err)
6616 }
6617
6618 testTransportRace(req)
6619 }
6620 }
6621
6622
6623
6624
6625 func TestCancelRequestWhenSharingConnection(t *testing.T) {
6626 run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
6627 }
6628 func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
6629 reqc := make(chan chan struct{}, 2)
6630 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
6631 ch := make(chan struct{}, 1)
6632 reqc <- ch
6633 <-ch
6634 w.Header().Add("Content-Length", "0")
6635 })).ts
6636
6637 client := ts.Client()
6638 transport := client.Transport.(*Transport)
6639 transport.MaxIdleConns = 1
6640 transport.MaxConnsPerHost = 1
6641
6642 var wg sync.WaitGroup
6643
6644 wg.Add(1)
6645 putidlec := make(chan chan struct{}, 1)
6646 reqerrc := make(chan error, 1)
6647 go func() {
6648 defer wg.Done()
6649 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6650 PutIdleConn: func(error) {
6651
6652
6653 ch := make(chan struct{})
6654 putidlec <- ch
6655 close(putidlec)
6656 <-ch
6657 },
6658 })
6659 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
6660 res, err := client.Do(req)
6661 reqerrc <- err
6662 if err == nil {
6663 res.Body.Close()
6664 }
6665 }()
6666
6667
6668
6669 r1c := <-reqc
6670 close(r1c)
6671 var idlec chan struct{}
6672 select {
6673 case err := <-reqerrc:
6674 if err != nil {
6675 t.Fatalf("request 1: got err %v, want nil", err)
6676 }
6677 idlec = <-putidlec
6678 case idlec = <-putidlec:
6679 }
6680
6681 wg.Add(1)
6682 cancelctx, cancel := context.WithCancel(context.Background())
6683 go func() {
6684 defer wg.Done()
6685 req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
6686 res, err := client.Do(req)
6687 if err == nil {
6688 res.Body.Close()
6689 }
6690 if !errors.Is(err, context.Canceled) {
6691 t.Errorf("request 2: got err %v, want Canceled", err)
6692 }
6693
6694
6695 close(idlec)
6696 }()
6697
6698
6699
6700 r2c := <-reqc
6701 cancel()
6702
6703 <-idlec
6704
6705 close(r2c)
6706 wg.Wait()
6707 }
6708
6709 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
6710 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
6711 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6712 go io.Copy(io.Discard, req.Body)
6713 panic(ErrAbortHandler)
6714 })).ts
6715
6716 var wg sync.WaitGroup
6717 for i := 0; i < 2; i++ {
6718 wg.Add(1)
6719 go func() {
6720 defer wg.Done()
6721 for j := 0; j < 10; j++ {
6722 const reqLen = 6 * 1024 * 1024
6723 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
6724 req.ContentLength = reqLen
6725 resp, _ := ts.Client().Transport.RoundTrip(req)
6726 if resp != nil {
6727 resp.Body.Close()
6728 }
6729 }
6730 }()
6731 }
6732 wg.Wait()
6733 }
6734
6735 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
6736 func testRequestSanitization(t *testing.T, mode testMode) {
6737 if mode == http2Mode {
6738
6739 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
6740 }
6741 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6742 if h, ok := req.Header["X-Evil"]; ok {
6743 t.Errorf("request has X-Evil header: %q", h)
6744 }
6745 })).ts
6746 req, _ := NewRequest("GET", ts.URL, nil)
6747 req.Host = "go.dev\r\nX-Evil:evil"
6748 resp, _ := ts.Client().Do(req)
6749 if resp != nil {
6750 resp.Body.Close()
6751 }
6752 }
6753
View as plain text