// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Tests for transport.go. // // More tests are in clientserver_test.go (for things testing both client & server for both // HTTP/1 and HTTP/2). This package http_test import ( "bufio" "bytes" "compress/gzip" "context" "crypto/rand" "crypto/tls" "crypto/x509" "encoding/binary" "errors" "fmt" "go/token" "internal/nettrace" "io" "log" mrand "math/rand" "net" . "net/http" "net/http/httptest" "net/http/httptrace" "net/http/httputil" "net/http/internal/testcert" "net/textproto" "net/url" "os" "reflect" "runtime" "strconv" "strings" "sync" "sync/atomic" "testing" "testing/iotest" "time" "golang.org/x/net/http/httpguts" ) // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close // and then verify that the final 2 responses get errors back. // hostPortHandler writes back the client's "host:port". var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("close") == "true" { w.Header().Set("Connection", "close") } w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close)) w.Write([]byte(r.RemoteAddr)) // Include the address of the net.Conn in addition to the RemoteAddr, // in case kernels reuse source ports quickly (see Issue 52450) if c, ok := ResponseWriterConnForTesting(w); ok { fmt.Fprintf(w, ", %T %p", c, c) } }) // testCloseConn is a net.Conn tracked by a testConnSet. type testCloseConn struct { net.Conn set *testConnSet } func (c *testCloseConn) Close() error { c.set.remove(c) return c.Conn.Close() } // testConnSet tracks a set of TCP connections and whether they've // been closed. 
type testConnSet struct { t *testing.T mu sync.Mutex // guards closed and list closed map[net.Conn]bool list []net.Conn // in order created } func (tcs *testConnSet) insert(c net.Conn) { tcs.mu.Lock() defer tcs.mu.Unlock() tcs.closed[c] = false tcs.list = append(tcs.list, c) } func (tcs *testConnSet) remove(c net.Conn) { tcs.mu.Lock() defer tcs.mu.Unlock() tcs.closed[c] = true } // some tests use this to manage raw tcp connections for later inspection func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) { connSet := &testConnSet{ t: t, closed: make(map[net.Conn]bool), } dial := func(n, addr string) (net.Conn, error) { c, err := net.Dial(n, addr) if err != nil { return nil, err } tc := &testCloseConn{c, connSet} connSet.insert(tc) return tc, nil } return connSet, dial } func (tcs *testConnSet) check(t *testing.T) { tcs.mu.Lock() defer tcs.mu.Unlock() for i := 4; i >= 0; i-- { for i, c := range tcs.list { if tcs.closed[c] { continue } if i != 0 { // TODO(bcmills): What is the Sleep here doing, and why is this // Unlock/Sleep/Lock cycle needed at all? tcs.mu.Unlock() time.Sleep(50 * time.Millisecond) tcs.mu.Lock() continue } t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) } } } func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) } func testReuseRequest(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("{}")) })).ts c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) res, err := c.Do(req) if err != nil { t.Fatal(err) } err = res.Body.Close() if err != nil { t.Fatal(err) } res, err = c.Do(req) if err != nil { t.Fatal(err) } err = res.Body.Close() if err != nil { t.Fatal(err) } } // Two subsequent requests and verify their response is the same. 
// The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) } func testTransportKeepAlives(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() for _, disableKeepAlive := range []bool{false, true} { c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive fetch := func(n int) string { res, err := c.Get(ts.URL) if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) } body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) } return string(body) } body1 := fetch(1) body2 := fetch(2) bodiesDiffer := body1 != body2 if bodiesDiffer != disableKeepAlive { t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", disableKeepAlive, bodiesDiffer, body1, body2) } } } func TestTransportConnectionCloseOnResponse(t *testing.T) { run(t, testTransportConnectionCloseOnResponse) } func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, hostPortHandler).ts connSet, testDial := makeTestDial(t) c := ts.Client() tr := c.Transport.(*Transport) tr.Dial = testDial for _, connectionClose := range []bool{false, true} { fetch := func(n int) string { req := new(Request) var err error req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) if err != nil { t.Fatalf("URL parse error: %v", err) } req.Method = "GET" req.Proto = "HTTP/1.1" req.ProtoMajor = 1 req.ProtoMinor = 1 res, err := c.Do(req) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) } defer res.Body.Close() body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } return string(body) } body1 := fetch(1) body2 := fetch(2) bodiesDiffer 
:= body1 != body2 if bodiesDiffer != connectionClose { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } tr.CloseIdleConnections() } connSet.check(t) } // TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse // an underlying TCP connection after making an http.Request with Request.Close set. // // It tests the behavior by making an HTTP request to a server which // describes the source connection it got (remote port number + // address of its net.Conn). func TestTransportConnectionCloseOnRequest(t *testing.T) { run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode}) } func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, hostPortHandler).ts connSet, testDial := makeTestDial(t) c := ts.Client() tr := c.Transport.(*Transport) tr.Dial = testDial for _, reqClose := range []bool{false, true} { fetch := func(n int) string { req := new(Request) var err error req.URL, err = url.Parse(ts.URL) if err != nil { t.Fatalf("URL parse error: %v", err) } req.Method = "GET" req.Proto = "HTTP/1.1" req.ProtoMajor = 1 req.ProtoMinor = 1 req.Close = reqClose res, err := c.Do(req) if err != nil { t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err) } if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want { t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v", reqClose, got, !reqClose) } body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err) } return string(body) } body1 := fetch(1) body2 := fetch(2) got := 1 if body1 != body2 { got++ } want := 1 if reqClose { want = 2 } if got != want { t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q", reqClose, got, want, body1, body2) } tr.CloseIdleConnections() } connSet.check(t) } 
// if the Transport's DisableKeepAlives is set, all requests should // send Connection: close. // HTTP/1-only (Connection: close doesn't exist in h2) func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode}) } func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = true res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } res.Body.Close() if res.Header.Get("X-Saw-Close") != "true" { t.Errorf("handler didn't see Connection: close ") } } // Test that Transport only sends one "Connection: close", regardless of // how "close" was indicated. func TestTransportRespectRequestWantsClose(t *testing.T) { run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode}) } func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) { tests := []struct { disableKeepAlives bool close bool }{ {disableKeepAlives: false, close: false}, {disableKeepAlives: false, close: true}, {disableKeepAlives: true, close: false}, {disableKeepAlives: true, close: true}, } for _, tc := range tests { t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close), func(t *testing.T) { ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } count := 0 trace := &httptrace.ClientTrace{ WroteHeaderField: func(key string, field []string) { if key != "Connection" { return } if httpguts.HeaderValuesContainsToken(field, "close") { count += 1 } }, } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) req.Close = tc.close res, err := c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() if want := tc.disableKeepAlives || tc.close; 
count > 1 || (count == 1) != want { t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count) } }) } } func TestTransportIdleCacheKeys(t *testing.T) { run(t, testTransportIdleCacheKeys, []testMode{http1Mode}) } func testTransportIdleCacheKeys(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() tr := c.Transport.(*Transport) if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) } resp, err := c.Get(ts.URL) if err != nil { t.Error(err) } io.ReadAll(resp.Body) keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) } if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) } tr.CloseIdleConnections() if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) } } // Tests that the HTTP transport re-uses connections when a client // reads to the end of a response Body without closing it. 
func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) } func testTransportReadToEndReusesConn(t *testing.T, mode testMode) { const msg = "foobar" var addrSeen map[string]int ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrSeen[r.RemoteAddr]++ if r.URL.Path == "/chunked/" { w.WriteHeader(200) w.(Flusher).Flush() } else { w.Header().Set("Content-Length", strconv.Itoa(len(msg))) w.WriteHeader(200) } w.Write([]byte(msg)) })).ts for pi, path := range []string{"/content-length/", "/chunked/"} { wantLen := []int{len(msg), -1}[pi] addrSeen = make(map[string]int) for i := 0; i < 3; i++ { res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Errorf("Get %s: %v", path, err) continue } // We want to close this body eventually (before the // defer afterTest at top runs), but not before the // len(addrSeen) check at the bottom of this test, // since Closing this early in the loop would risk // making connections be re-used for the wrong reason. 
defer res.Body.Close() if res.ContentLength != int64(wantLen) { t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) } got, err := io.ReadAll(res.Body) if string(got) != msg || err != nil { t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg) } } if len(addrSeen) != 1 { t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) } } } func TestTransportMaxPerHostIdleConns(t *testing.T) { run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode}) } func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) { stop := make(chan struct{}) // stop marks the exit of main Test goroutine defer close(stop) resch := make(chan string) gotReq := make(chan bool) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotReq <- true var msg string select { case <-stop: return case msg = <-resch: } _, err := w.Write([]byte(msg)) if err != nil { t.Errorf("Write: %v", err) return } })).ts c := ts.Client() tr := c.Transport.(*Transport) maxIdleConnsPerHost := 2 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost // Start 3 outstanding requests and wait for the server to get them. // Their responses will hang until we write to resch, though. 
donech := make(chan bool) doReq := func() { defer func() { select { case <-stop: return case donech <- t.Failed(): } }() resp, err := c.Get(ts.URL) if err != nil { t.Error(err) return } if _, err := io.ReadAll(resp.Body); err != nil { t.Errorf("ReadAll: %v", err) return } } go doReq() <-gotReq go doReq() <-gotReq go doReq() <-gotReq if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) } resch <- "res1" <-donech keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) } addr := ts.Listener.Addr().String() cacheKey := "|http|" + addr if keys[0] != cacheKey { t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) } if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g { t.Errorf("after first response, expected %d idle conns; got %d", e, g) } resch <- "res2" <-donech if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w { t.Errorf("after second response, idle conns = %d; want %d", g, w) } resch <- "res3" <-donech if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w { t.Errorf("after third response, idle conns = %d; want %d", g, w) } } func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { run(t, testTransportMaxConnsPerHostIncludeDialInProgress) } func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } })).ts c := ts.Client() tr := c.Transport.(*Transport) dialStarted := make(chan struct{}) stallDial := make(chan struct{}) tr.Dial = func(network, addr string) (net.Conn, error) { dialStarted <- struct{}{} <-stallDial return net.Dial(network, addr) } tr.DisableKeepAlives = true tr.MaxConnsPerHost = 1 preDial := make(chan struct{}) reqComplete 
:= make(chan struct{}) doReq := func(reqId string) { req, _ := NewRequest("GET", ts.URL, nil) trace := &httptrace.ClientTrace{ GetConn: func(hostPort string) { preDial <- struct{}{} }, } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) resp, err := tr.RoundTrip(req) if err != nil { t.Errorf("unexpected error for request %s: %v", reqId, err) } _, err = io.ReadAll(resp.Body) if err != nil { t.Errorf("unexpected error for request %s: %v", reqId, err) } reqComplete <- struct{}{} } // get req1 to dial-in-progress go doReq("req1") <-preDial <-dialStarted // get req2 to waiting on conns per host to go down below max go doReq("req2") <-preDial select { case <-dialStarted: t.Error("req2 dial started while req1 dial in progress") return default: } // let req1 complete stallDial <- struct{}{} <-reqComplete // let req2 complete <-dialStarted stallDial <- struct{}{} <-reqComplete } func TestTransportMaxConnsPerHost(t *testing.T) { run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode}) } func testTransportMaxConnsPerHost(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } }) ts := newClientServerTest(t, mode, h).ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxConnsPerHost = 1 mu := sync.Mutex{} var conns []net.Conn var dialCnt, gotConnCnt, tlsHandshakeCnt int32 tr.Dial = func(network, addr string) (net.Conn, error) { atomic.AddInt32(&dialCnt, 1) c, err := net.Dial(network, addr) mu.Lock() defer mu.Unlock() conns = append(conns, c) return c, err } doReq := func() { trace := &httptrace.ClientTrace{ GotConn: func(connInfo httptrace.GotConnInfo) { if !connInfo.Reused { atomic.AddInt32(&gotConnCnt, 1) } }, TLSHandshakeStart: func() { atomic.AddInt32(&tlsHandshakeCnt, 1) }, } req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) resp, err := 
c.Do(req) if err != nil { t.Fatalf("request failed: %v", err) } defer resp.Body.Close() _, err = io.ReadAll(resp.Body) if err != nil { t.Fatalf("read body failed: %v", err) } } wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() doReq() }() } wg.Wait() expected := int32(tr.MaxConnsPerHost) if dialCnt != expected { t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected) } if gotConnCnt != expected { t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected) } if ts.TLS != nil && tlsHandshakeCnt != expected { t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) } if t.Failed() { t.FailNow() } mu.Lock() for _, c := range conns { c.Close() } conns = nil mu.Unlock() tr.CloseIdleConnections() doReq() expected++ if dialCnt != expected { t.Errorf("round 2: too many dials: %d", dialCnt) } if gotConnCnt != expected { t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected) } if ts.TLS != nil && tlsHandshakeCnt != expected { t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) } } func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) { run(t, testTransportMaxConnsPerHostDialCancellation, testNotParallel, // because test uses SetPendingDialHooks []testMode{http1Mode, https1Mode, http2Mode}, ) } func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } }) cst := newClientServerTest(t, mode, h) defer cst.close() ts := cst.ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxConnsPerHost = 1 // This request is cancelled when dial is queued, which preempts dialing. 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() SetPendingDialHooks(cancel, nil) defer SetPendingDialHooks(nil, nil) req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) _, err := c.Do(req) if !errors.Is(err, context.Canceled) { t.Errorf("expected error %v, got %v", context.Canceled, err) } // This request should succeed. SetPendingDialHooks(nil, nil) req, _ = NewRequest("GET", ts.URL, nil) resp, err := c.Do(req) if err != nil { t.Fatalf("request failed: %v", err) } defer resp.Body.Close() _, err = io.ReadAll(resp.Body) if err != nil { t.Fatalf("read body failed: %v", err) } } func TestTransportRemovesDeadIdleConnections(t *testing.T) { run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode}) } func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) })).ts c := ts.Client() tr := c.Transport.(*Transport) doReq := func(name string) { // Do a POST instead of a GET to prevent the Transport's // idempotent request retry logic from kicking in... res, err := c.Post(ts.URL, "", nil) if err != nil { t.Fatalf("%s: %v", name, err) } if res.StatusCode != 200 { t.Fatalf("%s: %v", name, res.Status) } defer res.Body.Close() slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("%s: %v", name, err) } t.Logf("%s: ok (%q)", name, slurp) } doReq("first") keys1 := tr.IdleConnKeysForTesting() ts.CloseClientConnections() var keys2 []string waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { keys2 = tr.IdleConnKeysForTesting() if len(keys2) != 0 { if d > 0 { t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2) } return false } return true }) doReq("second") } // Test that the Transport notices when a server hangs up on its // unexpectedly (a keep-alive connection is closed). 
func TestTransportServerClosingUnexpectedly(t *testing.T) { run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode}) } func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() fetch := func(n, retries int) string { condFatalf := func(format string, arg ...any) { if retries <= 0 { t.Fatalf(format, arg...) } t.Logf("retrying shortly after expected error: "+format, arg...) time.Sleep(time.Second / time.Duration(retries)) } for retries >= 0 { retries-- res, err := c.Get(ts.URL) if err != nil { condFatalf("error in req #%d, GET: %v", n, err) continue } body, err := io.ReadAll(res.Body) if err != nil { condFatalf("error in req #%d, ReadAll: %v", n, err) continue } res.Body.Close() return string(body) } panic("unreachable") } body1 := fetch(1, 0) body2 := fetch(2, 0) // Close all the idle connections in a way that's similar to // the server hanging up on us. We don't use // httptest.Server.CloseClientConnections because it's // best-effort and stops blocking after 5 seconds. On a loaded // machine running many tests concurrently it's possible for // that method to be async and cause the body3 fetch below to // run on an old connection. This function is synchronous. ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) body3 := fetch(3, 5) if body1 != body2 { t.Errorf("expected body1 and body2 to be equal") } if body2 == body3 { t.Errorf("expected body2 and body3 to be different") } } // Test for https://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. 
func TestStressSurpriseServerCloses(t *testing.T) { run(t, testStressSurpriseServerCloses, []testMode{http1Mode}) } func testStressSurpriseServerCloses(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in short mode") } ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "5") w.Header().Set("Content-Type", "text/plain") w.Write([]byte("Hello")) w.(Flusher).Flush() conn, buf, _ := w.(Hijacker).Hijack() buf.Flush() conn.Close() })).ts c := ts.Client() // Do a bunch of traffic from different goroutines. Send to activityc // after each request completes, regardless of whether it failed. // If these are too high, OS X exhausts its ephemeral ports // and hangs waiting for them to transition TCP states. That's // not what we want to test. TODO(bradfitz): use an io.Pipe // dialer for this test instead? const ( numClients = 20 reqsPerClient = 25 ) var wg sync.WaitGroup wg.Add(numClients * reqsPerClient) for i := 0; i < numClients; i++ { go func() { for i := 0; i < reqsPerClient; i++ { res, err := c.Get(ts.URL) if err == nil { // We expect errors since the server is // hanging up on us after telling us to // send more requests, so we don't // actually care what the error is. // But we want to close the body in cases // where we won the race. res.Body.Close() } wg.Done() } }() } // Make sure all the request come back, one way or another. 
wg.Wait() } // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) } func testTransportHeadResponses(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Content-Length", "123") w.WriteHeader(200) })).ts c := ts.Client() for i := 0; i < 2; i++ { res, err := c.Head(ts.URL) if err != nil { t.Errorf("error on loop %d: %v", i, err) continue } if e, g := "123", res.Header.Get("Content-Length"); e != g { t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) } if e, g := int64(123), res.ContentLength; e != g { t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } if all, err := io.ReadAll(res.Body); err != nil { t.Errorf("loop %d: Body ReadAll: %v", i, err) } else if len(all) != 0 { t.Errorf("Bogus body %q", all) } } } // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. 
func TestTransportHeadChunkedResponse(t *testing.T) { run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel) } func testTransportHeadChunkedResponse(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Transfer-Encoding", "chunked") // client should ignore w.Header().Set("x-client-ipport", r.RemoteAddr) w.WriteHeader(200) })).ts c := ts.Client() // Ensure that we wait for the readLoop to complete before // calling Head again didRead := make(chan bool) SetReadLoopBeforeNextReadHook(func() { didRead <- true }) defer SetReadLoopBeforeNextReadHook(nil) res1, err := c.Head(ts.URL) <-didRead if err != nil { t.Fatalf("request 1 error: %v", err) } res2, err := c.Head(ts.URL) <-didRead if err != nil { t.Fatalf("request 2 error: %v", err) } if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) } } var roundTripTests = []struct { accept string expectAccept string compressed bool }{ // Requests with no accept-encoding header use transparent compression {"", "gzip", false}, // Requests with other accept-encoding should pass through unmodified {"foo", "foo", false}, // Requests with accept-encoding == gzip should be passed through {"gzip", "gzip", true}, } // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) } func testRoundTripGzip(t *testing.T, mode testMode) { const responseBody = "test response body" ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") if expect := req.FormValue("expect_accept"); accept != expect { t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", req.FormValue("testnum"), accept, expect) } if 
accept == "gzip" { rw.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(rw) gz.Write([]byte(responseBody)) gz.Close() } else { rw.Header().Set("Content-Encoding", accept) rw.Write([]byte(responseBody)) } })).ts tr := ts.Client().Transport.(*Transport) for i, test := range roundTripTests { // Test basic request (no accept-encoding) req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) if test.accept != "" { req.Header.Set("Accept-Encoding", test.accept) } res, err := tr.RoundTrip(req) if err != nil { t.Errorf("%d. RoundTrip: %v", i, err) continue } var body []byte if test.compressed { var r *gzip.Reader r, err = gzip.NewReader(res.Body) if err != nil { t.Errorf("%d. gzip NewReader: %v", i, err) continue } body, err = io.ReadAll(r) res.Body.Close() } else { body, err = io.ReadAll(res.Body) } if err != nil { t.Errorf("%d. Error: %q", i, err) continue } if g, e := string(body), responseBody; g != e { t.Errorf("%d. body = %q; want %q", i, g, e) } if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) } if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { t.Errorf("%d. 
Content-Encoding = %q; want %q", i, g, e) } } } func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) } func testTransportGzip(t *testing.T, mode testMode) { if mode == http2Mode { t.Skip("https://go.dev/issue/56020") } const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { if req.Method == "HEAD" { if g := req.Header.Get("Accept-Encoding"); g != "" { t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) } return } if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { t.Errorf("Accept-Encoding = %q, want %q", g, e) } rw.Header().Set("Content-Encoding", "gzip") var w io.Writer = rw var buf bytes.Buffer if req.FormValue("chunked") == "0" { w = &buf defer io.Copy(rw, &buf) defer func() { rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) }() } gz := gzip.NewWriter(w) gz.Write([]byte(testString)) if req.FormValue("body") == "large" { io.CopyN(gz, rand.Reader, nRandBytes) } gz.Close() })).ts c := ts.Client() for _, chunked := range []string{"1", "0"} { // First fetch something large, but only read some of it. res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) if err != nil { t.Fatalf("large get: %v", err) } buf := make([]byte, len(testString)) n, err := io.ReadFull(res.Body, buf) if err != nil { t.Fatalf("partial read of large response: size=%d, %v", n, err) } if e, g := testString, string(buf); e != g { t.Errorf("partial read got %q, expected %q", g, e) } res.Body.Close() // Read on the body, even though it's closed n, err = res.Body.Read(buf) if n != 0 || err == nil { t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) } // Then something small. 
res, err = c.Get(ts.URL + "/?chunked=" + chunked) if err != nil { t.Fatal(err) } body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if g, e := string(body), testString; g != e { t.Fatalf("body = %q; want %q", g, e) } if g, e := res.Header.Get("Content-Encoding"), ""; g != e { t.Fatalf("Content-Encoding = %q; want %q", g, e) } // Read on the body after it's been fully read: n, err = res.Body.Read(buf) if n != 0 || err == nil { t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) } res.Body.Close() n, err = res.Body.Read(buf) if n != 0 || err == nil { t.Errorf("expected Read error after Close; got %d, %v", n, err) } } // And a HEAD request too, because they're always weird. res, err := c.Head(ts.URL) if err != nil { t.Fatalf("Head: %v", err) } if res.StatusCode != 200 { t.Errorf("Head status=%d; want=200", res.StatusCode) } } // If a request has Expect:100-continue header, the request blocks sending body until the first response. // Premature consumption of the request body should not be occurred. func TestTransportExpect100Continue(t *testing.T) { run(t, testTransportExpect100Continue, []testMode{http1Mode}) } func testTransportExpect100Continue(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { switch req.URL.Path { case "/100": // This endpoint implicitly responds 100 Continue and reads body. if _, err := io.Copy(io.Discard, req.Body); err != nil { t.Error("Failed to read Body", err) } rw.WriteHeader(StatusOK) case "/200": // Go 1.5 adds Connection: close header if the client expect // continue but not entire request body is consumed. rw.WriteHeader(StatusOK) case "/500": rw.WriteHeader(StatusInternalServerError) case "/keepalive": // This hijacked endpoint responds error without Connection:close. 
_, bufrw, err := rw.(Hijacker).Hijack() if err != nil { log.Fatal(err) } bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") bufrw.WriteString("Content-Length: 0\r\n\r\n") bufrw.Flush() case "/timeout": // This endpoint tries to read body without 100 (Continue) response. // After ExpectContinueTimeout, the reading will be started. conn, bufrw, err := rw.(Hijacker).Hijack() if err != nil { log.Fatal(err) } if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { t.Error("Failed to read Body", err) } bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") bufrw.Flush() conn.Close() } })).ts tests := []struct { path string body []byte sent int status int }{ {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. } c := ts.Client() for i, v := range tests { tr := &Transport{ ExpectContinueTimeout: 2 * time.Second, } defer tr.CloseIdleConnections() c.Transport = tr body := bytes.NewReader(v.body) req, err := NewRequest("PUT", ts.URL+v.path, body) if err != nil { t.Fatal(err) } req.Header.Set("Expect", "100-continue") req.ContentLength = int64(len(v.body)) resp, err := c.Do(req) if err != nil { t.Fatal(err) } resp.Body.Close() sent := len(v.body) - body.Len() if v.status != resp.StatusCode { t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) } if v.sent != sent { t.Errorf("test %d: sent body should be %d but sent %d. 
(%s)", i, v.sent, sent, v.path) } } } func TestSOCKS5Proxy(t *testing.T) { run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode}) } func testSOCKS5Proxy(t *testing.T, mode testMode) { ch := make(chan string, 1) l := newLocalListener(t) defer l.Close() defer close(ch) proxy := func(t *testing.T) { s, err := l.Accept() if err != nil { t.Errorf("socks5 proxy Accept(): %v", err) return } defer s.Close() var buf [22]byte if _, err := io.ReadFull(s, buf[:3]); err != nil { t.Errorf("socks5 proxy initial read: %v", err) return } if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want) return } if _, err := s.Write([]byte{5, 0}); err != nil { t.Errorf("socks5 proxy initial write: %v", err) return } if _, err := io.ReadFull(s, buf[:4]); err != nil { t.Errorf("socks5 proxy second read: %v", err) return } if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want) return } var ipLen int switch buf[3] { case 1: ipLen = net.IPv4len case 4: ipLen = net.IPv6len default: t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) return } if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { t.Errorf("socks5 proxy address read: %v", err) return } ip := net.IP(buf[4 : ipLen+4]) port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6]) copy(buf[:3], []byte{5, 0, 0}) if _, err := s.Write(buf[:ipLen+6]); err != nil { t.Errorf("socks5 proxy connect write: %v", err) return } ch <- fmt.Sprintf("proxy for %s:%d", ip, port) // Implement proxying. targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) targetConn, err := net.Dial("tcp", targetHost) if err != nil { t.Errorf("net.Dial failed") return } go io.Copy(targetConn, s) io.Copy(s, targetConn) // Wait for the client to close the socket. 
targetConn.Close() } pu, err := url.Parse("socks5://" + l.Addr().String()) if err != nil { t.Fatal(err) } sentinelHeader := "X-Sentinel" sentinelValue := "12345" h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set(sentinelHeader, sentinelValue) }) for _, useTLS := range []bool{false, true} { t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { ts := newClientServerTest(t, mode, h).ts go proxy(t) c := ts.Client() c.Transport.(*Transport).Proxy = ProxyURL(pu) r, err := c.Head(ts.URL) if err != nil { t.Fatal(err) } if r.Header.Get(sentinelHeader) != sentinelValue { t.Errorf("Failed to retrieve sentinel value") } got := <-ch ts.Close() tsu, err := url.Parse(ts.URL) if err != nil { t.Fatal(err) } want := "proxy for " + tsu.Host if got != want { t.Errorf("got %q, want %q", got, want) } }) } } func TestTransportProxy(t *testing.T) { defer afterTest(t) testCases := []struct{ siteMode, proxyMode testMode }{ {http1Mode, http1Mode}, {http1Mode, https1Mode}, {https1Mode, http1Mode}, {https1Mode, https1Mode}, } for _, testCase := range testCases { siteMode := testCase.siteMode proxyMode := testCase.proxyMode t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) { siteCh := make(chan *Request, 1) h1 := HandlerFunc(func(w ResponseWriter, r *Request) { siteCh <- r }) proxyCh := make(chan *Request, 1) h2 := HandlerFunc(func(w ResponseWriter, r *Request) { proxyCh <- r // Implement an entire CONNECT proxy if r.Method == "CONNECT" { hijacker, ok := w.(Hijacker) if !ok { t.Errorf("hijack not allowed") return } clientConn, _, err := hijacker.Hijack() if err != nil { t.Errorf("hijacking failed") return } res := &Response{ StatusCode: StatusOK, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(Header), } targetConn, err := net.Dial("tcp", r.URL.Host) if err != nil { t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) return } if err := res.Write(clientConn); err != nil { t.Errorf("Writing 200 OK failed: %v", err) return } go 
io.Copy(targetConn, clientConn) go func() { io.Copy(clientConn, targetConn) targetConn.Close() }() } }) ts := newClientServerTest(t, siteMode, h1).ts proxy := newClientServerTest(t, proxyMode, h2).ts pu, err := url.Parse(proxy.URL) if err != nil { t.Fatal(err) } // If neither server is HTTPS or both are, then c may be derived from either. // If only one server is HTTPS, c must be derived from that server in order // to ensure that it is configured to use the fake root CA from testcert.go. c := proxy.Client() if siteMode == https1Mode { c = ts.Client() } c.Transport.(*Transport).Proxy = ProxyURL(pu) if _, err := c.Head(ts.URL); err != nil { t.Error(err) } got := <-proxyCh c.Transport.(*Transport).CloseIdleConnections() ts.Close() proxy.Close() if siteMode == https1Mode { // First message should be a CONNECT, asking for a socket to the real server, if got.Method != "CONNECT" { t.Errorf("Wrong method for secure proxying: %q", got.Method) } gotHost := got.URL.Host pu, err := url.Parse(ts.URL) if err != nil { t.Fatal("Invalid site URL") } if wantHost := pu.Host; gotHost != wantHost { t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) } // The next message on the channel should be from the site's server. 
next := <-siteCh if next.Method != "HEAD" { t.Errorf("Wrong method at destination: %s", next.Method) } if nextURL := next.URL.String(); nextURL != "/" { t.Errorf("Wrong URL at destination: %s", nextURL) } } else { if got.Method != "HEAD" { t.Errorf("Wrong method for destination: %q", got.Method) } gotURL := got.URL.String() wantURL := ts.URL + "/" if gotURL != wantURL { t.Errorf("Got URL %q, want %q", gotURL, wantURL) } } }) } } func TestOnProxyConnectResponse(t *testing.T) { var tcases = []struct { proxyStatusCode int err error }{ { StatusOK, nil, }, { StatusForbidden, errors.New("403"), }, } for _, tcase := range tcases { h1 := HandlerFunc(func(w ResponseWriter, r *Request) { }) h2 := HandlerFunc(func(w ResponseWriter, r *Request) { // Implement an entire CONNECT proxy if r.Method == "CONNECT" { if tcase.proxyStatusCode != StatusOK { w.WriteHeader(tcase.proxyStatusCode) return } hijacker, ok := w.(Hijacker) if !ok { t.Errorf("hijack not allowed") return } clientConn, _, err := hijacker.Hijack() if err != nil { t.Errorf("hijacking failed") return } res := &Response{ StatusCode: StatusOK, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(Header), } targetConn, err := net.Dial("tcp", r.URL.Host) if err != nil { t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) return } if err := res.Write(clientConn); err != nil { t.Errorf("Writing 200 OK failed: %v", err) return } go io.Copy(targetConn, clientConn) go func() { io.Copy(clientConn, targetConn) targetConn.Close() }() } }) ts := newClientServerTest(t, https1Mode, h1).ts proxy := newClientServerTest(t, https1Mode, h2).ts pu, err := url.Parse(proxy.URL) if err != nil { t.Fatal(err) } c := proxy.Client() c.Transport.(*Transport).Proxy = ProxyURL(pu) c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { if proxyURL.String() != pu.String() { t.Errorf("proxy url got %s, want %s", proxyURL, pu) } if 
"https://"+connectReq.URL.String() != ts.URL { t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL) } return tcase.err } if _, err := c.Head(ts.URL); err != nil { if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) { t.Errorf("got %v, want %v", err, tcase.err) } } } } // Issue 28012: verify that the Transport closes its TCP connection to http proxies // when they're slow to reply to HTTPS CONNECT responses. func TestTransportProxyHTTPSConnectLeak(t *testing.T) { setParallel(t) defer afterTest(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() ln := newLocalListener(t) defer ln.Close() listenerDone := make(chan struct{}) go func() { defer close(listenerDone) c, err := ln.Accept() if err != nil { t.Errorf("Accept: %v", err) return } defer c.Close() // Read the CONNECT request br := bufio.NewReader(c) cr, err := ReadRequest(br) if err != nil { t.Errorf("proxy server failed to read CONNECT request") return } if cr.Method != "CONNECT" { t.Errorf("unexpected method %q", cr.Method) return } // Now hang and never write a response; instead, cancel the request and wait // for the client to close. // (Prior to Issue 28012 being fixed, we never closed.) cancel() var buf [1]byte _, err = br.Read(buf[:]) if err != io.EOF { t.Errorf("proxy server Read err = %v; want EOF", err) } return }() c := &Client{ Transport: &Transport{ Proxy: func(*Request) (*url.URL, error) { return url.Parse("http://" + ln.Addr().String()) }, }, } req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) if err != nil { t.Fatal(err) } _, err = c.Do(req) if err == nil { t.Errorf("unexpected Get success") } // Wait unconditionally for the listener goroutine to exit: this should never // hang, so if it does we want a full goroutine dump — and that's exactly what // the testing package will give us when the test run times out. 
<-listenerDone } // Issue 16997: test transport dial preserves typed errors func TestTransportDialPreservesNetOpProxyError(t *testing.T) { defer afterTest(t) var errDial = errors.New("some dial error") tr := &Transport{ Proxy: func(*Request) (*url.URL, error) { return url.Parse("http://proxy.fake.tld/") }, Dial: func(string, string) (net.Conn, error) { return nil, errDial }, } defer tr.CloseIdleConnections() c := &Client{Transport: tr} req, _ := NewRequest("GET", "http://fake.tld", nil) res, err := c.Do(req) if err == nil { res.Body.Close() t.Fatal("wanted a non-nil error") } uerr, ok := err.(*url.Error) if !ok { t.Fatalf("got %T, want *url.Error", err) } oe, ok := uerr.Err.(*net.OpError) if !ok { t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) } want := &net.OpError{ Op: "proxyconnect", Net: "tcp", Err: errDial, // original error, unwrapped. } if !reflect.DeepEqual(oe, want) { t.Errorf("Got error %#v; want %#v", oe, want) } } // Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader. // // (A bug caused dialConn to instead write the per-request Proxy-Authorization // header through to the shared Header instance, introducing a data race.) 
func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader) } func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) { proxy := newClientServerTest(t, mode, NotFoundHandler()).ts defer proxy.Close() c := proxy.Client() tr := c.Transport.(*Transport) tr.Proxy = func(*Request) (*url.URL, error) { u, _ := url.Parse(proxy.URL) u.User = url.UserPassword("aladdin", "opensesame") return u, nil } h := tr.ProxyConnectHeader if h == nil { h = make(Header) } tr.ProxyConnectHeader = h.Clone() req, err := NewRequest("GET", "https://golang.fake.tld/", nil) if err != nil { t.Fatal(err) } _, err = c.Do(req) if err == nil { t.Errorf("unexpected Get success") } if !reflect.DeepEqual(tr.ProxyConnectHeader, h) { t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h) } } // TestTransportGzipRecursive sends a gzip quine and checks that the // client gets the same value back. This is more cute than anything, // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. 
func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) } func testTransportGzipRecursive(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) })).ts c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if !bytes.Equal(body, rgz) { t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", body, rgz) } if g, e := res.Header.Get("Content-Encoding"), ""; g != e { t.Fatalf("Content-Encoding = %q; want %q", g, e) } } // golang.org/issue/7750: request fails when server replies with // a short gzip body func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) } func testTransportGzipShort(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write([]byte{0x1f, 0x8b}) })).ts c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() _, err = io.ReadAll(res.Body) if err == nil { t.Fatal("Expect an error from reading a body.") } if err != io.ErrUnexpectedEOF { t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) } } // Wait until number of goroutines is no greater than nmax, or time out. func waitNumGoroutine(nmax int) int { nfinal := runtime.NumGoroutine() for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- { time.Sleep(50 * time.Millisecond) runtime.GC() nfinal = runtime.NumGoroutine() } return nfinal } // tests that persistent goroutine connections shut down when no longer desired. 
func TestTransportPersistConnLeak(t *testing.T) { run(t, testTransportPersistConnLeak, testNotParallel) } func testTransportPersistConnLeak(t *testing.T, mode testMode) { if mode == http2Mode { t.Skip("flaky in HTTP/2") } // Not parallel: counts goroutines const numReq = 25 gotReqCh := make(chan bool, numReq) unblockCh := make(chan bool, numReq) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotReqCh <- true <-unblockCh w.Header().Set("Content-Length", "0") w.WriteHeader(204) })).ts c := ts.Client() tr := c.Transport.(*Transport) n0 := runtime.NumGoroutine() didReqCh := make(chan bool, numReq) failed := make(chan bool, numReq) for i := 0; i < numReq; i++ { go func() { res, err := c.Get(ts.URL) didReqCh <- true if err != nil { t.Logf("client fetch error: %v", err) failed <- true return } res.Body.Close() }() } // Wait for all goroutines to be stuck in the Handler. for i := 0; i < numReq; i++ { select { case <-gotReqCh: // ok case <-failed: // Not great but not what we are testing: // sometimes an overloaded system will fail to make all the connections. } } nhigh := runtime.NumGoroutine() // Tell all handlers to unblock and reply. close(unblockCh) // Wait for all HTTP clients to be done. for i := 0; i < numReq; i++ { <-didReqCh } tr.CloseIdleConnections() nfinal := waitNumGoroutine(n0 + 5) growth := nfinal - n0 // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. // Previously we were leaking one per numReq. 
if int(growth) > 5 { t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) t.Error("too many new goroutines") } } // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { run(t, testTransportPersistConnLeakShortBody, testNotParallel) } func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) { if mode == http2Mode { t.Skip("flaky in HTTP/2") } // Not parallel: measures goroutines. ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { })).ts c := ts.Client() tr := c.Transport.(*Transport) n0 := runtime.NumGoroutine() body := []byte("Hello") for i := 0; i < 20; i++ { req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) if err != nil { t.Fatal(err) } req.ContentLength = int64(len(body) - 2) // explicitly short _, err = c.Do(req) if err == nil { t.Fatal("Expect an error from writing too long of a body.") } } nhigh := runtime.NumGoroutine() tr.CloseIdleConnections() nfinal := waitNumGoroutine(n0 + 5) growth := nfinal - n0 // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. // Previously we were leaking one per numReq. t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) if int(growth) > 5 { t.Error("too many new goroutines") } } // A countedConn is a net.Conn that decrements an atomic counter when finalized. type countedConn struct { net.Conn } // A countingDialer dials connections and counts the number that remain reachable. 
type countingDialer struct { dialer net.Dialer mu sync.Mutex total, live int64 } func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.dialer.DialContext(ctx, network, address) if err != nil { return nil, err } counted := new(countedConn) counted.Conn = conn d.mu.Lock() defer d.mu.Unlock() d.total++ d.live++ runtime.SetFinalizer(counted, d.decrement) return counted, nil } func (d *countingDialer) decrement(*countedConn) { d.mu.Lock() defer d.mu.Unlock() d.live-- } func (d *countingDialer) Read() (total, live int64) { d.mu.Lock() defer d.mu.Unlock() return d.total, d.live } func TestTransportPersistConnLeakNeverIdle(t *testing.T) { run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode}) } func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Close every connection so that it cannot be kept alive. conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Errorf("Hijack failed unexpectedly: %v", err) return } conn.Close() })).ts var d countingDialer c := ts.Client() c.Transport.(*Transport).DialContext = d.DialContext body := []byte("Hello") for i := 0; ; i++ { total, live := d.Read() if live < total { break } if i >= 1<<12 { t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i) } req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) if err != nil { t.Fatal(err) } _, err = c.Do(req) if err == nil { t.Fatal("expected broken connection") } runtime.GC() } } type countedContext struct { context.Context } type contextCounter struct { mu sync.Mutex live int64 } func (cc *contextCounter) Track(ctx context.Context) context.Context { counted := new(countedContext) counted.Context = ctx cc.mu.Lock() defer cc.mu.Unlock() cc.live++ runtime.SetFinalizer(counted, cc.decrement) return counted } func (cc *contextCounter) 
decrement(*countedContext) { cc.mu.Lock() defer cc.mu.Unlock() cc.live-- } func (cc *contextCounter) Read() (live int64) { cc.mu.Lock() defer cc.mu.Unlock() return cc.live } func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { run(t, testTransportPersistConnContextLeakMaxConnsPerHost) } func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) { if mode == http2Mode { t.Skip("https://go.dev/issue/56021") } ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { runtime.Gosched() w.WriteHeader(StatusOK) })).ts c := ts.Client() c.Transport.(*Transport).MaxConnsPerHost = 1 ctx := context.Background() body := []byte("Hello") doPosts := func(cc *contextCounter) { var wg sync.WaitGroup for n := 64; n > 0; n-- { wg.Add(1) go func() { defer wg.Done() ctx := cc.Track(ctx) req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) if err != nil { t.Error(err) } _, err = c.Do(req.WithContext(ctx)) if err != nil { t.Errorf("Do failed with error: %v", err) } }() } wg.Wait() } var initialCC contextCounter doPosts(&initialCC) // flushCC exists only to put pressure on the GC to finalize the initialCC // contexts: the flushCC allocations should eventually displace the initialCC // allocations. 
var flushCC contextCounter for i := 0; ; i++ { live := initialCC.Read() if live == 0 { break } if i >= 100 { t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i) } doPosts(&flushCC) runtime.GC() } } // This used to crash; https://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) } func testTransportIdleConnCrash(t *testing.T, mode testMode) { var tr *Transport unblockCh := make(chan bool, 1) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockCh tr.CloseIdleConnections() })).ts c := ts.Client() tr = c.Transport.(*Transport) didreq := make(chan bool) go func() { res, err := c.Get(ts.URL) if err != nil { t.Error(err) } else { res.Body.Close() // returns idle conn } didreq <- true }() unblockCh <- true <-didreq } // Test that the transport doesn't close the TCP connection early, // before the response body has been read. This was a regression // which sadly lacked a triggering test. The large response body made // the old race easier to trigger. func TestIssue3644(t *testing.T) { run(t, testIssue3644) } func testIssue3644(t *testing.T, mode testMode) { const numFoos = 5000 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") for i := 0; i < numFoos; i++ { w.Write([]byte("foo ")) } })).ts c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() bs, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if len(bs) != numFoos*len("foo ") { t.Errorf("unexpected response length") } } // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. func TestIssue3595(t *testing.T) { // Not parallel: modifies the global rstAvoidanceDelay. 
run(t, testIssue3595, testNotParallel) } func testIssue3595(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 1 * time.Millisecond, 5 * time.Millisecond, 10 * time.Millisecond, 50 * time.Millisecond, 100 * time.Millisecond, 500 * time.Millisecond, time.Second, 5 * time.Second, }, func(t *testing.T, timeout time.Duration) error { SetRSTAvoidanceDelay(t, timeout) t.Logf("set RST avoidance delay to %v", timeout) const deniedMsg = "sorry, denied." cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, deniedMsg, StatusUnauthorized) })) // We need to close cst explicitly here so that in-flight server // requests don't race with the call to SetRSTAvoidanceDelay for a retry. defer cst.close() ts := cst.ts c := ts.Client() res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) if err != nil { return fmt.Errorf("Post: %v", err) } got, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("Body ReadAll: %v", err) } t.Logf("server response:\n%s", got) if !strings.Contains(string(got), deniedMsg) { // If we got an RST packet too early, we should have seen an error // from io.ReadAll, not a silently-truncated body. 
t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) } return nil }) } // From https://golang.org/issue/4454 , // "client fails to handle requests with no body and chunked encoding" func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) } func testChunkedNoContent(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) })).ts c := ts.Client() for _, closeBody := range []bool{true, false} { const n = 4 for i := 1; i <= n; i++ { res, err := c.Get(ts.URL) if err != nil { t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) } else { if closeBody { res.Body.Close() } } } } } func TestTransportConcurrency(t *testing.T) { run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode}) } func testTransportConcurrency(t *testing.T, mode testMode) { // Not parallel: uses global test hooks. maxProcs, numReqs := 16, 500 if testing.Short() { maxProcs, numReqs = 4, 50 } defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.FormValue("echo")) })).ts var wg sync.WaitGroup wg.Add(numReqs) // Due to the Transport's "socket late binding" (see // idleConnCh in transport.go), the numReqs HTTP requests // below can finish with a dial still outstanding. To keep // the leak checker happy, keep track of pending dials and // wait for them to finish (and be closed or returned to the // idle pool) before we close idle connections. 
SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) defer SetPendingDialHooks(nil, nil) c := ts.Client() reqs := make(chan string) defer close(reqs) for i := 0; i < maxProcs*2; i++ { go func() { for req := range reqs { res, err := c.Get(ts.URL + "/?echo=" + req) if err != nil { if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") { // https://go.dev/issue/52168: this test was observed to fail with // ECONNRESET errors in Dial on various netbsd builders. t.Logf("error on req %s: %v", req, err) t.Logf("(see https://go.dev/issue/52168)") } else { t.Errorf("error on req %s: %v", req, err) } wg.Done() continue } all, err := io.ReadAll(res.Body) if err != nil { t.Errorf("read error on req %s: %v", req, err) } else if string(all) != req { t.Errorf("body of req %s = %q; want %q", req, all, req) } res.Body.Close() wg.Done() } }() } for i := 0; i < numReqs; i++ { reqs <- fmt.Sprintf("request-%d", i) } wg.Wait() } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) } func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) { mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) ts := newClientServerTest(t, mode, mux).ts connc := make(chan net.Conn, 1) c := ts.Client() c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { conn, err := net.Dial(n, addr) if err != nil { return nil, err } select { case connc <- conn: default: } return conn, nil } res, err := c.Get(ts.URL + "/get") if err != nil { t.Fatalf("Error issuing GET: %v", err) } defer res.Body.Close() conn := <-connc conn.SetDeadline(time.Now().Add(1 * time.Millisecond)) _, err = io.Copy(io.Discard, res.Body) if err == nil { t.Errorf("Unexpected successful copy") } } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode}) } func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode 
testMode) { const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { defer r.Body.Close() io.Copy(io.Discard, r.Body) }) ts := newClientServerTest(t, mode, mux).ts timeout := 100 * time.Millisecond c := ts.Client() c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { conn, err := net.Dial(n, addr) if err != nil { return nil, err } conn.SetDeadline(time.Now().Add(timeout)) if debug { conn = NewLoggingConn("client", conn) } return conn, nil } getFailed := false nRuns := 5 if testing.Short() { nRuns = 1 } for i := 0; i < nRuns; i++ { if debug { println("run", i+1, "of", nRuns) } sres, err := c.Get(ts.URL + "/get") if err != nil { if !getFailed { // Make the timeout longer, once. getFailed = true t.Logf("increasing timeout") i-- timeout *= 10 continue } t.Errorf("Error issuing GET: %v", err) break } req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) _, err = c.Do(req) if err == nil { sres.Body.Close() t.Errorf("Unexpected successful PUT") break } sres.Body.Close() } if debug { println("tests complete; waiting for handlers to finish") } ts.Close() } func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) } func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping timeout test in -short mode") } timeout := 2 * time.Millisecond retry := true for retry && !t.Failed() { var srvWG sync.WaitGroup inHandler := make(chan bool, 1) mux := NewServeMux() mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { inHandler <- true srvWG.Done() }) mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { inHandler <- true <-r.Context().Done() srvWG.Done() }) ts := newClientServerTest(t, mode, mux).ts c := ts.Client() c.Transport.(*Transport).ResponseHeaderTimeout = timeout retry = false srvWG.Add(3) tests := []struct { path string 
wantTimeout bool }{ {path: "/fast"}, {path: "/slow", wantTimeout: true}, {path: "/fast"}, } for i, tt := range tests { req, _ := NewRequest("GET", ts.URL+tt.path, nil) req = req.WithT(t) res, err := c.Do(req) <-inHandler if err != nil { uerr, ok := err.(*url.Error) if !ok { t.Errorf("error is not a url.Error; got: %#v", err) continue } nerr, ok := uerr.Err.(net.Error) if !ok { t.Errorf("error does not satisfy net.Error interface; got: %#v", err) continue } if !nerr.Timeout() { t.Errorf("want timeout error; got: %q", nerr) continue } if !tt.wantTimeout { if !retry { // The timeout may be set too short. Retry with a longer one. t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout) timeout *= 2 retry = true } } if !strings.Contains(err.Error(), "timeout awaiting response headers") { t.Errorf("%d. unexpected error: %v", i, err) } continue } if tt.wantTimeout { t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path) continue } if res.StatusCode != 200 { t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode) } } srvWG.Wait() ts.Close() } } func TestTransportCancelRequest(t *testing.T) { run(t, testTransportCancelRequest, []testMode{http1Mode}) } func testTransportCancelRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } const msg = "Hello" unblockc := make(chan bool) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() // send headers and some body <-unblockc })).ts defer close(unblockc) c := ts.Client() tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) res, err := c.Do(req) if err != nil { t.Fatal(err) } body := make([]byte, len(msg)) n, _ := io.ReadFull(res.Body, body) if n != len(body) || !bytes.Equal(body, []byte(msg)) { t.Errorf("Body = %q; want %q", body[:n], msg) } tr.CancelRequest(req) tail, err := io.ReadAll(res.Body) 
// NOTE(review): this chunk begins partway through the preceding test function;
// the statements up to its closing brace are that test's final checks.
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}

// testTransportCancelRequestInDo issues a GET (with the optional request body)
// from a background goroutine and keeps calling Transport.CancelRequest on it
// until that goroutine's Client.Do call has returned. Skipped in -short mode.
func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// The handler parks on unblockc; the deferred close releases any handler
	// still parked there when the test returns.
	unblockc := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// donec is closed once the background Do returns, whatever its outcome.
	donec := make(chan bool)
	req, _ := NewRequest("GET", ts.URL, body)
	go func() {
		defer close(donec)
		c.Do(req)
	}()

	// unblockc is unbuffered, so this send completes only once the handler
	// has received the request.
	unblockc <- true
	// NOTE(review): waitCondition is defined elsewhere; it appears to retry the
	// callback (passing the elapsed delay d) until it returns true.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		tr.CancelRequest(req)
		select {
		case <-donec:
			return true
		default:
			if d > 0 {
				t.Logf("Do of canceled request has not returned after %v", d)
			}
			return false
		}
	})
}

// TestTransportCancelRequestInDo runs the CancelRequest-during-Do check for a
// GET without a request body (HTTP/1 only).
func TestTransportCancelRequestInDo(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportCancelRequestInDo(t, mode, nil)
	}, []testMode{http1Mode})
}

// TestTransportCancelRequestWithBodyInDo runs the same check with a one-byte
// request body (HTTP/1 only).
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
	}, []testMode{http1Mode})
}

// TestTransportCancelRequestInDial cancels a request while the Transport's
// Dial hook is blocked and checks, via an ordered event log, that Do fails
// with "request canceled while waiting for connection". It also calls
// CancelRequest twice to guard against a past panic on the second call.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	unblockDial := make(chan bool)
	defer close(unblockDial)

	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			// Receiving false (e.g. from a closed channel) is treated as the
			// test goroutine having gone away.
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	// When this unbuffered send completes, Dial has already logged
	// "dial: blocking", which makes the event order below deterministic.
	inDial <- true

	eventLog.Printf("canceling")
	tr.CancelRequest(req)
	tr.CancelRequest(req) // used to panic on second call

	if d, ok := t.Deadline(); ok {
		// When the test's deadline is about to expire, log the pending events for
		// better debugging.
		timeout := time.Until(d) * 19 / 20 // Allow 5% for cleanup.
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}

	<-gotres
	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}

func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }

// testCancelRequestWithChannel reads the "Hello" prefix of a response whose
// handler then blocks. It cancels the request by closing req.Cancel and checks
// two things: reading the rest of the body fails with errRequestCanceled, and
// no further bytes are returned. It then waits for the Transport to report no
// pending requests. Skipped in -short mode.
func testCancelRequestWithChannel(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	unblockc := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush() // send headers and some body
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	// Cancel only after the flushed prefix has been consumed.
	close(cancel)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}

// Issue 51354
func TestCancelRequestWithBodyWithChannel(t *testing.T) {
	run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
}

// testCancelRequestWithBodyWithChannel is testCancelRequestWithChannel for a
// POST that carries a request body. Skipped in -short mode.
func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	unblockc := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush() // send headers and some body
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	close(cancel)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}

// TestCancelRequestWithChannelBeforeDo_Cancel: the request's Cancel channel is
// already closed when Do is called.
func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testCancelRequestWithChannelBeforeDo(t, mode, false)
	})
}

// TestCancelRequestWithChannelBeforeDo_Context: the request's context is
// already canceled when Do is called.
func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testCancelRequestWithChannelBeforeDo(t, mode, true)
	})
}

// testCancelRequestWithChannelBeforeDo cancels a request before calling Do.
// With withCtx it uses a canceled context and expects context.Canceled
// (after unwrapping *url.Error). Otherwise it uses a closed req.Cancel channel
// and expects an error whose text contains "canceled".
func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
	unblockc := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()

	req, _ := NewRequest("GET", ts.URL, nil)
	if withCtx {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		req = req.WithContext(ctx)
	} else {
		ch := make(chan struct{})
		req.Cancel = ch
		close(ch)
	}

	_, err := c.Do(req)
	// Client.Do wraps errors in *url.Error; compare against the inner error.
	if ue, ok := err.(*url.Error); ok {
		err = ue.Err
	}
	if withCtx {
		if err != context.Canceled {
			t.Errorf("Do error = %v; want %v", err, context.Canceled)
		}
	} else {
		if err == nil || !strings.Contains(err.Error(), "canceled") {
			t.Errorf("Do error = %v; want cancellation", err)
		}
	}
}

// Issue 11020.
The returned error message should be errRequestCanceled func TestTransportCancelBeforeResponseHeaders(t *testing.T) { defer afterTest(t) serverConnCh := make(chan net.Conn, 1) tr := &Transport{ Dial: func(network, addr string) (net.Conn, error) { cc, sc := net.Pipe() serverConnCh <- sc return cc, nil }, } defer tr.CloseIdleConnections() errc := make(chan error, 1) req, _ := NewRequest("GET", "http://example.com/", nil) go func() { _, err := tr.RoundTrip(req) errc <- err }() sc := <-serverConnCh verb := make([]byte, 3) if _, err := io.ReadFull(sc, verb); err != nil { t.Errorf("Error reading HTTP verb from server: %v", err) } if string(verb) != "GET" { t.Errorf("server received %q; want GET", verb) } defer sc.Close() tr.CancelRequest(req) err := <-errc if err == nil { t.Fatalf("unexpected success from RoundTrip") } if err != ExportErrRequestCanceled { t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err) } } // golang.org/issue/3672 -- Client can't close HTTP stream // Calling Close on a Response.Body used to just read until EOF. // Now it actually closes the TCP connection. 
func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) } func testTransportCloseResponseBody(t *testing.T, mode testMode) { writeErr := make(chan error, 1) msg := []byte("young\n") ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { for { _, err := w.Write(msg) if err != nil { writeErr <- err return } w.(Flusher).Flush() } })).ts c := ts.Client() tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) defer tr.CancelRequest(req) res, err := c.Do(req) if err != nil { t.Fatal(err) } const repeats = 3 buf := make([]byte, len(msg)*repeats) want := bytes.Repeat(msg, repeats) _, err = io.ReadFull(res.Body, buf) if err != nil { t.Fatal(err) } if !bytes.Equal(buf, want) { t.Fatalf("read %q; want %q", buf, want) } if err := res.Body.Close(); err != nil { t.Errorf("Close = %v", err) } if err := <-writeErr; err == nil { t.Errorf("expected non-nil write error") } } type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { res := &Response{ Status: "200 OK", StatusCode: 200, Header: make(Header), Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())), } return res, nil } func TestTransportAltProto(t *testing.T) { defer afterTest(t) tr := &Transport{} c := &Client{Transport: tr} tr.RegisterProtocol("foo", fooProto{}) res, err := c.Get("foo://bar.com/path") if err != nil { t.Fatal(err) } bodyb, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } body := string(bodyb) if e := "You wanted foo://bar.com/path"; body != e { t.Errorf("got response %q, want %q", body, e) } } func TestTransportNoHost(t *testing.T) { defer afterTest(t) tr := &Transport{} _, err := tr.RoundTrip(&Request{ Header: make(Header), URL: &url.URL{ Scheme: "http", }, }) want := "http: no Host in request URL" if got := fmt.Sprint(err); got != want { t.Errorf("error = %v; want %q", err, want) } } // Issue 13311 func TestTransportEmptyMethod(t *testing.T) { req, _ := NewRequest("GET", 
"http://foo.com/", nil) req.Method = "" // docs say "For client requests an empty string means GET" got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport if err != nil { t.Fatal(err) } if !strings.Contains(string(got), "GET ") { t.Fatalf("expected substring 'GET '; got: %s", got) } } func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) } func testTransportSocketLateBinding(t *testing.T, mode testMode) { mux := NewServeMux() fooGate := make(chan bool, 1) mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { w.Header().Set("foo-ipport", r.RemoteAddr) w.(Flusher).Flush() <-fooGate }) mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { w.Header().Set("bar-ipport", r.RemoteAddr) }) ts := newClientServerTest(t, mode, mux).ts dialGate := make(chan bool, 1) dialing := make(chan bool) c := ts.Client() c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { for { select { case ok := <-dialGate: if !ok { return nil, errors.New("manually closed") } return net.Dial(n, addr) case dialing <- true: } } } defer close(dialGate) dialGate <- true // only allow one dial fooRes, err := c.Get(ts.URL + "/foo") if err != nil { t.Fatal(err) } fooAddr := fooRes.Header.Get("foo-ipport") if fooAddr == "" { t.Fatal("No addr on /foo request") } fooDone := make(chan struct{}) go func() { // We know that the foo Dial completed and reached the handler because we // read its header. Wait for the bar request to block in Dial, then // let the foo response finish so we can use its connection for /bar. if mode == http2Mode { // In HTTP/2 mode, the second Dial won't happen because the protocol // multiplexes the streams by default. Just sleep for an arbitrary time; // the test should pass regardless of how far the bar request gets by this // point. 
select { case <-dialing: t.Errorf("unexpected second Dial in HTTP/2 mode") case <-time.After(10 * time.Millisecond): } } else { <-dialing } fooGate <- true io.Copy(io.Discard, fooRes.Body) fooRes.Body.Close() close(fooDone) }() defer func() { <-fooDone }() barRes, err := c.Get(ts.URL + "/bar") if err != nil { t.Fatal(err) } barAddr := barRes.Header.Get("bar-ipport") if barAddr != fooAddr { t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) } barRes.Body.Close() } // Issue 2184 func TestTransportReading100Continue(t *testing.T) { defer afterTest(t) const numReqs = 5 reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } send100Response := func(w *io.PipeWriter, r *io.PipeReader) { defer w.Close() defer r.Close() br := bufio.NewReader(r) n := 0 for { n++ req, err := ReadRequest(br) if err == io.EOF { return } if err != nil { t.Error(err) return } slurp, err := io.ReadAll(req.Body) if err != nil { t.Errorf("Server request body slurp: %v", err) return } id := req.Header.Get("Request-Id") resCode := req.Header.Get("X-Want-Response-Code") if resCode == "" { resCode = "100 Continue" if string(slurp) != reqBody(n) { t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) } } body := fmt.Sprintf("Response number %d", n) v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s Date: Thu, 28 Feb 2013 17:55:41 GMT HTTP/1.1 200 OK Content-Type: text/html Echo-Request-Id: %s Content-Length: %d %s`, resCode, id, len(body), body), "\n", "\r\n", -1)) w.Write(v) if id == reqID(numReqs) { return } } } tr := &Transport{ Dial: func(n, addr string) (net.Conn, error) { sr, sw := io.Pipe() // server read/write cr, cw := io.Pipe() // client read/write conn := &rwTestConn{ Reader: cr, Writer: sw, closeFunc: func() error { sw.Close() cw.Close() return nil }, } go send100Response(cw, sr) return conn, nil }, DisableKeepAlives: false, } defer tr.CloseIdleConnections() c := 
&Client{Transport: tr} testResponse := func(req *Request, name string, wantCode int) { t.Helper() res, err := c.Do(req) if err != nil { t.Fatalf("%s: Do: %v", name, err) } if res.StatusCode != wantCode { t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) } if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { t.Errorf("%s: response id %q != request id %q", name, idBack, id) } _, err = io.ReadAll(res.Body) if err != nil { t.Fatalf("%s: Slurp error: %v", name, err) } } // Few 100 responses, making sure we're not off-by-one. for i := 1; i <= numReqs; i++ { req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) req.Header.Set("Request-Id", reqID(i)) testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) } } // Issue 17739: the HTTP client must ignore any unknown 1xx // informational responses before the actual response. func TestTransportIgnore1xxResponses(t *testing.T) { run(t, testTransportIgnore1xxResponses, []testMode{http1Mode}) } func testTransportIgnore1xxResponses(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) buf.Flush() conn.Close() })) cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway var got strings.Builder req, _ := NewRequest("GET", cst.ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ Got1xxResponse: func(code int, header textproto.MIMEHeader) error { fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) return nil }, })) res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() res.Write(&got) want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: 
baz\r\n\r\nHello" if got.String() != want { t.Errorf(" got: %q\nwant: %q\n", got.String(), want) } } func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses, []testMode{http1Mode}) } func testTransportLimits1xxResponses(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() for i := 0; i < 10; i++ { buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) } buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) buf.Flush() conn.Close() })) cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway res, err := cst.c.Get(cst.ts.URL) if res != nil { defer res.Body.Close() } got := fmt.Sprint(err) wantSub := "too many 1xx informational responses" if !strings.Contains(got, wantSub) { t.Errorf("Get error = %v; want substring %q", err, wantSub) } } // Issue 26161: the HTTP client must treat 101 responses // as the final response. func TestTransportTreat101Terminal(t *testing.T) { run(t, testTransportTreat101Terminal, []testMode{http1Mode}) } func testTransportTreat101Terminal(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) buf.Flush() conn.Close() })) res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() if res.StatusCode != StatusSwitchingProtocols { t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode) } } type proxyFromEnvTest struct { req string // URL to fetch; blank means "http://example.com" env string // HTTP_PROXY httpsenv string // HTTPS_PROXY noenv string // NO_PROXY reqmeth string // REQUEST_METHOD want string wanterr error } func (t proxyFromEnvTest) String() string { var buf strings.Builder space := func() { if buf.Len() > 0 { 
buf.WriteByte(' ') } } if t.env != "" { fmt.Fprintf(&buf, "http_proxy=%q", t.env) } if t.httpsenv != "" { space() fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv) } if t.noenv != "" { space() fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) } if t.reqmeth != "" { space() fmt.Fprintf(&buf, "request_method=%q", t.reqmeth) } req := "http://example.com" if t.req != "" { req = t.req } space() fmt.Fprintf(&buf, "req=%q", req) return strings.TrimSpace(buf.String()) } var proxyFromEnvTests = []proxyFromEnvTest{ {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"}, // Don't use secure for http {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, // Use secure for https. {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, // Issue 16405: don't use HTTP_PROXY in a CGI environment, // where HTTP_PROXY can be attacker-controlled. 
{env: "http://10.1.2.3:8080", reqmeth: "POST", want: "", wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")}, {want: ""}, {noenv: "example.com", req: "http://example.com/", env: "proxy", want: ""}, {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: ""}, {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, } func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) { t.Helper() reqURL := tt.req if reqURL == "" { reqURL = "http://example.com" } req, _ := NewRequest("GET", reqURL, nil) url, err := proxyForRequest(req) if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { t.Errorf("%v: got error = %q, want %q", tt, g, e) return } if got := fmt.Sprintf("%s", url); got != tt.want { t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) } } func TestProxyFromEnvironment(t *testing.T) { ResetProxyEnv() defer ResetProxyEnv() for _, tt := range proxyFromEnvTests { testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { os.Setenv("HTTP_PROXY", tt.env) os.Setenv("HTTPS_PROXY", tt.httpsenv) os.Setenv("NO_PROXY", tt.noenv) os.Setenv("REQUEST_METHOD", tt.reqmeth) ResetCachedEnvironment() return ProxyFromEnvironment(req) }) } } func TestProxyFromEnvironmentLowerCase(t *testing.T) { ResetProxyEnv() defer ResetProxyEnv() for _, tt := range proxyFromEnvTests { testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { os.Setenv("http_proxy", tt.env) os.Setenv("https_proxy", tt.httpsenv) os.Setenv("no_proxy", tt.noenv) os.Setenv("REQUEST_METHOD", tt.reqmeth) ResetCachedEnvironment() return ProxyFromEnvironment(req) }) } } func TestIdleConnChannelLeak(t *testing.T) { run(t, testIdleConnChannelLeak, 
[]testMode{http1Mode}, testNotParallel) } func testIdleConnChannelLeak(t *testing.T, mode testMode) { // Not parallel: uses global test hooks. var mu sync.Mutex var n int ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() n++ mu.Unlock() })).ts const nReqs = 5 didRead := make(chan bool, nReqs) SetReadLoopBeforeNextReadHook(func() { didRead <- true }) defer SetReadLoopBeforeNextReadHook(nil) c := ts.Client() tr := c.Transport.(*Transport) tr.Dial = func(netw, addr string) (net.Conn, error) { return net.Dial(netw, ts.Listener.Addr().String()) } // First, without keep-alives. for _, disableKeep := range []bool{true, false} { tr.DisableKeepAlives = disableKeep for i := 0; i < nReqs; i++ { _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) if err != nil { t.Fatal(err) } // Note: no res.Body.Close is needed here, since the // response Content-Length is zero. Perhaps the test // should be more explicit and use a HEAD, but tests // elsewhere guarantee that zero byte responses generate // a "Content-Length: 0" instead of chunking. } // At this point, each of the 5 Transport.readLoop goroutines // are scheduling noting that there are no response bodies (see // earlier comment), and are then calling putIdleConn, which // decrements this count. Usually that happens quickly, which is // why this test has seemed to work for ages. But it's still // racey: we have wait for them to finish first. See Issue 10427 for i := 0; i < nReqs; i++ { <-didRead } if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 { t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) } } } // Verify the status quo: that the Client.Post function coerces its // body into a ReadCloser if it's a Closer, and that the Transport // then closes it. 
func TestTransportClosesRequestBody(t *testing.T) { run(t, testTransportClosesRequestBody, []testMode{http1Mode}) } func testTransportClosesRequestBody(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) })).ts c := ts.Client() closes := 0 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) if err != nil { t.Fatal(err) } res.Body.Close() if closes != 1 { t.Errorf("closes = %d; want 1", closes) } } func TestTransportTLSHandshakeTimeout(t *testing.T) { defer afterTest(t) if testing.Short() { t.Skip("skipping in short mode") } ln := newLocalListener(t) defer ln.Close() testdonec := make(chan struct{}) defer close(testdonec) go func() { c, err := ln.Accept() if err != nil { t.Error(err) return } <-testdonec c.Close() }() tr := &Transport{ Dial: func(_, _ string) (net.Conn, error) { return net.Dial("tcp", ln.Addr().String()) }, TLSHandshakeTimeout: 250 * time.Millisecond, } cl := &Client{Transport: tr} _, err := cl.Get("https://dummy.tld/") if err == nil { t.Error("expected error") return } ue, ok := err.(*url.Error) if !ok { t.Errorf("expected url.Error; got %#v", err) return } ne, ok := ue.Err.(net.Error) if !ok { t.Errorf("expected net.Error; got %#v", err) return } if !ne.Timeout() { t.Errorf("expected timeout error; got %v", err) } if !strings.Contains(err.Error(), "handshake timeout") { t.Errorf("expected 'handshake timeout' in error; got %v", err) } } // Trying to repro golang.org/issue/3514 func TestTLSServerClosesConnection(t *testing.T) { run(t, testTLSServerClosesConnection, []testMode{https1Mode}) } func testTLSServerClosesConnection(t *testing.T, mode testMode) { closedc := make(chan bool, 1) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.Path, "/keep-alive-then-die") { conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 
3\r\n\r\nfoo")) conn.Close() closedc <- true return } fmt.Fprintf(w, "hello") })).ts c := ts.Client() tr := c.Transport.(*Transport) var nSuccess = 0 var errs []error const trials = 20 for i := 0; i < trials; i++ { tr.CloseIdleConnections() res, err := c.Get(ts.URL + "/keep-alive-then-die") if err != nil { t.Fatal(err) } <-closedc slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if string(slurp) != "foo" { t.Errorf("Got %q, want foo", slurp) } // Now try again and see if we successfully // pick a new connection. res, err = c.Get(ts.URL + "/") if err != nil { errs = append(errs, err) continue } slurp, err = io.ReadAll(res.Body) if err != nil { errs = append(errs, err) continue } nSuccess++ } if nSuccess > 0 { t.Logf("successes = %d of %d", nSuccess, trials) } else { t.Errorf("All runs failed:") } for _, err := range errs { t.Logf(" err: %v", err) } } // byteFromChanReader is an io.Reader that reads a single byte at a // time from the channel. When the channel is closed, the reader // returns io.EOF. type byteFromChanReader chan byte func (c byteFromChanReader) Read(p []byte) (n int, err error) { if len(p) == 0 { return } b, ok := <-c if !ok { return 0, io.EOF } p[0] = b return 1, nil } // Verifies that the Transport doesn't reuse a connection in the case // where the server replies before the request has been fully // written. We still honor that reply (see TestIssue3595), but don't // send future requests on the connection because it's then in a // questionable state. 
// golang.org/issue/7569
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}

// testTransportNoReuseAfterEarlyResponse sends a POST whose final body byte is
// withheld until the test releases it. The handler hijacks the connection,
// replies "foo" early, and then only drains the connection. A following GET
// must receive "bar" from the handler. It cannot get that if it is written to
// the hijacked connection, so it must be sent on a different connection.
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shrink the package-level MaxWriteWaitBeforeConnReuse for this test and
	// restore it on exit (which is why the test is run with testNotParallel).
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server-side connection so it can be closed on exit.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay suppresses the "Closed server connection" log once the GET succeeded.
	var getOkay bool
	// copying tracks the goroutine draining the hijacked connection.
	var copying sync.WaitGroup
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive

		// Keep consuming (and discarding) the rest of the POST body so the
		// client's writer is not blocked on the socket.
		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	const bodySize = 256 << 10
	// finalBit supplies the last body byte only when the test sends it, so the
	// POST's body write cannot complete before then.
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true  // suppress test noise
	finalBit <- 'x' // unblock the writeloop of the first Post
	close(finalBit)
}

// Tests that we don't leak Transport persistConn.readLoop goroutines
// when a server hangs up immediately after saying it would keep-alive.
func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) } func testTransportIssue10457(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Send a response with no body, keep-alive // (implicit), and then lie and immediately close the // connection. This forces the Transport's readLoop to // immediately Peek an io.EOF and get to the point // that used to hang. conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive conn.Close() })).ts c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get: %v", err) } defer res.Body.Close() // Just a sanity check that we at least get the response. The real // test here is that the "defer afterTest" above doesn't find any // leaked goroutines. if got, want := res.Header.Get("Foo"), "Bar"; got != want { t.Errorf("Foo header = %q; want %q", got, want) } } type closerFunc func() error func (f closerFunc) Close() error { return f() } type writerFuncConn struct { net.Conn write func(p []byte) (n int, err error) } func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } // Issues 4677, 18241, and 17844. If we try to reuse a connection that the // server is in the process of closing, we may end up successfully writing out // our request (or a portion of our request) only to find a connection error // when we try to read from (or finish writing to) the socket. // // NOTE: we resend a request only if: // - we reused a keep-alive connection // - we haven't yet received any header data // - either we wrote no bytes to the server, or the request is idempotent // // This automatically prevents an infinite resend loop because we'll run out of // the cached keep-alive connections eventually. 
func TestRetryRequestsOnError(t *testing.T) { run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode}) } func testRetryRequestsOnError(t *testing.T, mode testMode) { newRequest := func(method, urlStr string, body io.Reader) *Request { req, err := NewRequest(method, urlStr, body) if err != nil { t.Fatal(err) } return req } testCases := []struct { name string failureN int failureErr error // Note that we can't just re-use the Request object across calls to c.Do // because we need to rewind Body between calls. (GetBody is only used to // rewind Body on failure and redirects, not just because it's done.) req func() *Request reqString string }{ { name: "IdempotentNoBodySomeWritten", // Believe that we've written some bytes to the server, so we know we're // not just in the "retry when no bytes sent" case". failureN: 1, // Use the specific error that shouldRetryRequest looks for with idempotent requests. failureErr: ExportErrServerClosedIdle, req: func() *Request { return newRequest("GET", "http://fake.golang", nil) }, reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, }, { name: "IdempotentGetBodySomeWritten", // Believe that we've written some bytes to the server, so we know we're // not just in the "retry when no bytes sent" case". failureN: 1, // Use the specific error that shouldRetryRequest looks for with idempotent requests. failureErr: ExportErrServerClosedIdle, req: func() *Request { return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n")) }, reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, }, { name: "NothingWrittenNoBody", // It's key that we return 0 here -- that's what enables Transport to know // that nothing was written, even though this is a non-idempotent request. 
failureN: 0, failureErr: errors.New("second write fails"), req: func() *Request { return newRequest("DELETE", "http://fake.golang", nil) }, reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, }, { name: "NothingWrittenGetBody", // It's key that we return 0 here -- that's what enables Transport to know // that nothing was written, even though this is a non-idempotent request. failureN: 0, failureErr: errors.New("second write fails"), // Note that NewRequest will set up GetBody for strings.Reader, which is // required for the retry to occur req: func() *Request { return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n")) }, reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var ( mu sync.Mutex logbuf strings.Builder ) logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&logbuf, format, args...) 
logbuf.WriteByte('\n') } ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { logf("Handler") w.Header().Set("X-Status", "ok") })).ts var writeNumAtomic int32 c := ts.Client() c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) { logf("Dial") c, err := net.Dial(network, ts.Listener.Addr().String()) if err != nil { logf("Dial error: %v", err) return nil, err } return &writerFuncConn{ Conn: c, write: func(p []byte) (n int, err error) { if atomic.AddInt32(&writeNumAtomic, 1) == 2 { logf("intentional write failure") return tc.failureN, tc.failureErr } logf("Write(%q)", p) return c.Write(p) }, }, nil } SetRoundTripRetried(func() { logf("Retried.") }) defer SetRoundTripRetried(nil) for i := 0; i < 3; i++ { t0 := time.Now() req := tc.req() res, err := c.Do(req) if err != nil { if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 { mu.Lock() got := logbuf.String() mu.Unlock() t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got) } t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse) } res.Body.Close() if res.Request != req { t.Errorf("Response.Request != original request; want identical Request") } } mu.Lock() got := logbuf.String() mu.Unlock() want := fmt.Sprintf(`Dial Write("%s") Handler intentional write failure Retried. Dial Write("%s") Handler Write("%s") Handler `, tc.reqString, tc.reqString, tc.reqString) if got != want { t.Errorf("Log of events differs. 
Got:\n%s\nWant:\n%s", got, want) } }) } } // Issue 6981 func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) } func testTransportClosesBodyOnError(t *testing.T, mode testMode) { readBody := make(chan error, 1) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.ReadAll(r.Body) readBody <- err })).ts c := ts.Client() fakeErr := errors.New("fake error") didClose := make(chan bool, 1) req, _ := NewRequest("POST", ts.URL, struct { io.Reader io.Closer }{ io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), closerFunc(func() error { select { case didClose <- true: default: } return nil }), }) res, err := c.Do(req) if res != nil { defer res.Body.Close() } if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) } if err := <-readBody; err == nil { t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") } select { case <-didClose: default: t.Errorf("didn't see Body.Close") } } func TestTransportDialTLS(t *testing.T) { run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode}) } func testTransportDialTLS(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq, didDial bool ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() })).ts c := ts.Client() c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { mu.Lock() didDial = true mu.Unlock() c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) if err != nil { return nil, err } return c, c.Handshake() } res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } res.Body.Close() mu.Lock() if !gotReq { t.Error("didn't get request") } if !didDial { t.Error("didn't use dial hook") } } func TestTransportDialContext(t *testing.T) { run(t, 
testTransportDialContext) } func testTransportDialContext(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() })).ts c := ts.Client() c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() receivedContext = ctx mu.Unlock() return net.Dial(netw, addr) } req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } ctx := context.WithValue(context.Background(), "some-key", "some-value") res, err := c.Do(req.WithContext(ctx)) if err != nil { t.Fatal(err) } res.Body.Close() mu.Lock() if !gotReq { t.Error("didn't get request") } if receivedContext != ctx { t.Error("didn't receive correct context") } } func TestTransportDialTLSContext(t *testing.T) { run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode}) } func testTransportDialTLSContext(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() })).ts c := ts.Client() c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() receivedContext = ctx mu.Unlock() c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) if err != nil { return nil, err } return c, c.HandshakeContext(ctx) } req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } ctx := context.WithValue(context.Background(), "some-key", "some-value") res, err := c.Do(req.WithContext(ctx)) if err != nil { t.Fatal(err) } res.Body.Close() mu.Lock() if !gotReq { t.Error("didn't get request") } if receivedContext != ctx { t.Error("didn't receive correct context") } } // Test for issue 8755 // Ensure that if a proxy returns 
an error, it is exposed by RoundTrip func TestRoundTripReturnsProxyError(t *testing.T) { badProxy := func(*Request) (*url.URL, error) { return nil, errors.New("errorMessage") } tr := &Transport{Proxy: badProxy} req, _ := NewRequest("GET", "http://example.com", nil) _, err := tr.RoundTrip(req) if err == nil { t.Error("Expected proxy error to be returned by RoundTrip") } } // tests that putting an idle conn after a call to CloseIdleConns does return it func TestTransportCloseIdleConnsThenReturn(t *testing.T) { tr := &Transport{} wantIdle := func(when string, n int) bool { got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn if got == n { return true } t.Errorf("%s: idle conns = %d; want %d", when, got, n) return false } wantIdle("start", 0) if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("put failed") } if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("second put failed") } wantIdle("after put", 2) tr.CloseIdleConnections() if !tr.IsIdleForTesting() { t.Error("should be idle after CloseIdleConnections") } wantIdle("after close idle", 0) if tr.PutIdleTestConn("http", "example.com") { t.Fatal("put didn't fail") } wantIdle("after second put", 0) tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode if tr.IsIdleForTesting() { t.Error("shouldn't be idle after QueueForIdleConnForTesting") } if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("after re-activation") } wantIdle("after final put", 1) } // Test for issue 34282 // Ensure that getConn doesn't call the GotConn trace hook on an HTTP/2 idle conn func TestTransportTraceGotConnH2IdleConns(t *testing.T) { tr := &Transport{} wantIdle := func(when string, n int) bool { got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 if got == n { return true } t.Errorf("%s: idle conns = %d; want %d", when, got, n) return false } wantIdle("start", 0) alt := funcRoundTripper(func() {}) if 
!tr.PutIdleTestConnH2("https", "example.com:443", alt) { t.Fatal("put failed") } wantIdle("after put", 1) ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ GotConn: func(httptrace.GotConnInfo) { // tr.getConn should leave it for the HTTP/2 alt to call GotConn. t.Error("GotConn called") }, }) req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil) _, err := tr.RoundTrip(req) if err != errFakeRoundTrip { t.Errorf("got error: %v; want %q", err, errFakeRoundTrip) } wantIdle("after round trip", 1) } func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode}) } func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } timeout := 1 * time.Millisecond retry := true for retry { trFunc := func(tr *Transport) { tr.MaxConnsPerHost = 1 tr.MaxIdleConnsPerHost = 1 tr.IdleConnTimeout = timeout } cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) retry = false tooShort := func(err error) bool { if err == nil || !strings.Contains(err.Error(), "use of closed network connection") { return false } if !retry { t.Helper() t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout) timeout *= 2 retry = true cst.close() } return true } if _, err := cst.c.Get(cst.ts.URL); err != nil { if tooShort(err) { continue } t.Fatalf("got error: %s", err) } time.Sleep(10 * timeout) if _, err := cst.c.Get(cst.ts.URL); err != nil { if tooShort(err) { continue } t.Fatalf("got error: %s", err) } } } // This tests that a client requesting a content range won't also // implicitly ask for gzip support. If they want that, they need to do it // on their own. 
// golang.org/issue/8923 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) } func testTransportRangeAndGzip(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { t.Error("Transport advertised gzip support in the Accept header") } if r.Header.Get("Range") == "" { t.Error("no Range in request") } })).ts c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) req.Header.Set("Range", "bytes=7-11") res, err := c.Do(req) if err != nil { t.Fatal(err) } res.Body.Close() } // Test for issue 10474 func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) } func testTransportResponseCancelRace(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // important that this response has a body. var b [1024]byte w.Write(b[:]) })).ts tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } // If we do an early close, Transport just throws the connection away and // doesn't reuse it. In order to trigger the bug, it has to reuse the connection // so read the body if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatal(err) } req2, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } tr.CancelRequest(req) res, err = tr.RoundTrip(req2) if err != nil { t.Fatal(err) } res.Body.Close() } // Test for issue 19248: Content-Encoding's value is case insensitive. 
func TestTransportContentEncodingCaseInsensitive(t *testing.T) { run(t, testTransportContentEncodingCaseInsensitive) } func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) { for _, ce := range []string{"gzip", "GZIP"} { ce := ce t.Run(ce, func(t *testing.T) { const encodedString = "Hello Gopher" ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", ce) gz := gzip.NewWriter(w) gz.Write([]byte(encodedString)) gz.Close() })).ts res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) } if string(body) != encodedString { t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body)) } }) } } func TestTransportDialCancelRace(t *testing.T) { run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode}) } func testTransportDialCancelRace(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } SetEnterRoundTripHook(func() { tr.CancelRequest(req) }) defer SetEnterRoundTripHook(nil) res, err := tr.RoundTrip(req) if err != ExportErrRequestCanceled { t.Errorf("expected canceled request error; got %v", err) if err == nil { res.Body.Close() } } } // https://go.dev/issue/49621 func TestConnClosedBeforeRequestIsWritten(t *testing.T) { run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode}) } func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), func(tr *Transport) { tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { // Connection immediately returns errors. 
return &funcConn{ read: func([]byte) (int, error) { return 0, errors.New("error") }, write: func([]byte) (int, error) { return 0, errors.New("error") }, }, nil } }, ).ts // Set a short delay in RoundTrip to give the persistConn time to notice // the connection is broken. We want to exercise the path where writeLoop exits // before it reads the request to send. If this delay is too short, we may instead // exercise the path where writeLoop accepts the request and then fails to write it. // That's fine, so long as we get the desired path often enough. SetEnterRoundTripHook(func() { time.Sleep(1 * time.Millisecond) }) defer SetEnterRoundTripHook(nil) var closes int _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) if err == nil { t.Fatalf("expected request to fail, but it did not") } if closes != 1 { t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes) } } // logWritesConn is a net.Conn that logs each Write call to writes // and then proxies to w. // It proxies Read calls to a reader it receives from rch. type logWritesConn struct { net.Conn // nil. crash on use. 
w io.Writer rch <-chan io.Reader r io.Reader // nil until received by rch mu sync.Mutex writes []string } func (c *logWritesConn) Write(p []byte) (n int, err error) { c.mu.Lock() defer c.mu.Unlock() c.writes = append(c.writes, string(p)) return c.w.Write(p) } func (c *logWritesConn) Read(p []byte) (n int, err error) { if c.r == nil { c.r = <-c.rch } return c.r.Read(p) } func (c *logWritesConn) Close() error { return nil } // Issue 6574 func TestTransportFlushesBodyChunks(t *testing.T) { defer afterTest(t) resBody := make(chan io.Reader, 1) connr, connw := io.Pipe() // connection pipe pair lw := &logWritesConn{ rch: resBody, w: connw, } tr := &Transport{ Dial: func(network, addr string) (net.Conn, error) { return lw, nil }, } bodyr, bodyw := io.Pipe() // body pipe pair go func() { defer bodyw.Close() for i := 0; i < 3; i++ { fmt.Fprintf(bodyw, "num%d\n", i) } }() resc := make(chan *Response) go func() { req, _ := NewRequest("POST", "http://localhost:8080", bodyr) req.Header.Set("User-Agent", "x") // known value for test res, err := tr.RoundTrip(req) if err != nil { t.Errorf("RoundTrip: %v", err) close(resc) return } resc <- res }() // Fully consume the request before checking the Write log vs. want. req, err := ReadRequest(bufio.NewReader(connr)) if err != nil { t.Fatal(err) } io.Copy(io.Discard, req.Body) // Unblock the transport's roundTrip goroutine. resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") res, ok := <-resc if !ok { return } defer res.Body.Close() want := []string{ "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n", "5\r\nnum0\n\r\n", "5\r\nnum1\n\r\n", "5\r\nnum2\n\r\n", "0\r\n\r\n", } if !reflect.DeepEqual(lw.writes, want) { t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want) } } // Issue 22088: flush Transport request headers if we're not sure the body won't block on read. 
func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) } func testTransportFlushesRequestHeader(t *testing.T, mode testMode) { gotReq := make(chan struct{}) cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(gotReq) })) pr, pw := io.Pipe() req, err := NewRequest("POST", cst.ts.URL, pr) if err != nil { t.Fatal(err) } gotRes := make(chan struct{}) go func() { defer close(gotRes) res, err := cst.tr.RoundTrip(req) if err != nil { t.Error(err) return } res.Body.Close() }() <-gotReq pw.Close() <-gotRes } type wgReadCloser struct { io.Reader wg *sync.WaitGroup closed bool } func (c *wgReadCloser) Close() error { if c.closed { return net.ErrClosed } c.closed = true c.wg.Done() return nil } // Issue 11745. func TestTransportPrefersResponseOverWriteError(t *testing.T) { // Not parallel: modifies the global rstAvoidanceDelay. run(t, testTransportPrefersResponseOverWriteError, testNotParallel) } func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } runTimeSensitiveTest(t, []time.Duration{ 1 * time.Millisecond, 5 * time.Millisecond, 10 * time.Millisecond, 50 * time.Millisecond, 100 * time.Millisecond, 500 * time.Millisecond, time.Second, 5 * time.Second, }, func(t *testing.T, timeout time.Duration) error { SetRSTAvoidanceDelay(t, timeout) t.Logf("set RST avoidance delay to %v", timeout) const contentLengthLimit = 1024 * 1024 // 1MB cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.ContentLength >= contentLengthLimit { w.WriteHeader(StatusBadRequest) r.Body.Close() return } w.WriteHeader(StatusOK) })) // We need to close cst explicitly here so that in-flight server // requests don't race with the call to SetRSTAvoidanceDelay for a retry. 
defer cst.close() ts := cst.ts c := ts.Client() count := 100 bigBody := strings.Repeat("a", contentLengthLimit*2) var wg sync.WaitGroup defer wg.Wait() getBody := func() (io.ReadCloser, error) { wg.Add(1) body := &wgReadCloser{ Reader: strings.NewReader(bigBody), wg: &wg, } return body, nil } for i := 0; i < count; i++ { reqBody, _ := getBody() req, err := NewRequest("PUT", ts.URL, reqBody) if err != nil { reqBody.Close() t.Fatal(err) } req.ContentLength = int64(len(bigBody)) req.GetBody = getBody resp, err := c.Do(req) if err != nil { return fmt.Errorf("Do %d: %v", i, err) } else { resp.Body.Close() if resp.StatusCode != 400 { t.Errorf("Expected status code 400, got %v", resp.Status) } } } return nil }) } func TestTransportAutomaticHTTP2(t *testing.T) { testTransportAutoHTTP(t, &Transport{}, true) } func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) { testTransportAutoHTTP(t, &Transport{ ForceAttemptHTTP2: true, TLSClientConfig: new(tls.Config), }, true) } // golang.org/issue/14391: also check DefaultTransport func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) { testTransportAutoHTTP(t, DefaultTransport.(*Transport), true) } func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { testTransportAutoHTTP(t, &Transport{ TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper), }, false) } func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) { testTransportAutoHTTP(t, &Transport{ TLSClientConfig: new(tls.Config), }, false) } func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) { testTransportAutoHTTP(t, &Transport{ ExpectContinueTimeout: 1 * time.Second, }, true) } func TestTransportAutomaticHTTP2_Dial(t *testing.T) { var d net.Dialer testTransportAutoHTTP(t, &Transport{ Dial: d.Dial, }, false) } func TestTransportAutomaticHTTP2_DialContext(t *testing.T) { var d net.Dialer testTransportAutoHTTP(t, &Transport{ DialContext: d.DialContext, }, false) } func 
TestTransportAutomaticHTTP2_DialTLS(t *testing.T) { testTransportAutoHTTP(t, &Transport{ DialTLS: func(network, addr string) (net.Conn, error) { panic("unused") }, }, false) } func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { CondSkipHTTP2(t) _, err := tr.RoundTrip(new(Request)) if err == nil { t.Error("expected error from RoundTrip") } if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 { t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2) } } // Issue 13633: there was a race where we returned bodyless responses // to callers before recycling the persistent connection, which meant // a client doing two subsequent requests could end up on different // connections. It's somewhat harmless but enough tests assume it's // not true in order to test other things that it's worth fixing. // Plus it's nice to be consistent and not have timing-dependent // behavior. func TestTransportReuseConnEmptyResponseBody(t *testing.T) { run(t, testTransportReuseConnEmptyResponseBody) } func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) // Empty response body. 
})) n := 100 if testing.Short() { n = 10 } var firstAddr string for i := 0; i < n; i++ { res, err := cst.c.Get(cst.ts.URL) if err != nil { log.Fatal(err) } addr := res.Header.Get("X-Addr") if i == 0 { firstAddr = addr } else if addr != firstAddr { t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) } res.Body.Close() } } // Issue 13839 func TestNoCrashReturningTransportAltConn(t *testing.T) { cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) if err != nil { t.Fatal(err) } ln := newLocalListener(t) defer ln.Close() var wg sync.WaitGroup SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) defer SetPendingDialHooks(nil, nil) testDone := make(chan struct{}) defer close(testDone) go func() { tln := tls.NewListener(ln, &tls.Config{ NextProtos: []string{"foo"}, Certificates: []tls.Certificate{cert}, }) sc, err := tln.Accept() if err != nil { t.Error(err) return } if err := sc.(*tls.Conn).Handshake(); err != nil { t.Error(err) return } <-testDone sc.Close() }() addr := ln.Addr().String() req, _ := NewRequest("GET", "https://fake.tld/", nil) cancel := make(chan struct{}) req.Cancel = cancel doReturned := make(chan bool, 1) madeRoundTripper := make(chan bool, 1) tr := &Transport{ DisableKeepAlives: true, TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ "foo": func(authority string, c *tls.Conn) RoundTripper { madeRoundTripper <- true return funcRoundTripper(func() { t.Error("foo RoundTripper should not be called") }) }, }, Dial: func(_, _ string) (net.Conn, error) { panic("shouldn't be called") }, DialTLS: func(_, _ string) (net.Conn, error) { tc, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"foo"}, }) if err != nil { return nil, err } if err := tc.Handshake(); err != nil { return nil, err } close(cancel) <-doReturned return tc, nil }, } c := &Client{Transport: tr} _, err = c.Do(req) if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn { 
t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err) } doReturned <- true <-madeRoundTripper wg.Wait() } func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { run(t, func(t *testing.T, mode testMode) { testTransportReuseConnection_Gzip(t, mode, true) }) } func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { run(t, func(t *testing.T, mode testMode) { testTransportReuseConnection_Gzip(t, mode, false) }) } // Make sure we re-use underlying TCP connection for gzipped responses too. func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) { addr := make(chan string, 2) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addr <- r.RemoteAddr w.Header().Set("Content-Encoding", "gzip") if chunked { w.(Flusher).Flush() } w.Write(rgz) // arbitrary gzip response })).ts c := ts.Client() trace := &httptrace.ClientTrace{ GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) }, GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) }, PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) }, ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) }, ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) }, } ctx := httptrace.WithClientTrace(context.Background(), trace) for i := 0; i < 2; i++ { req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(ctx) res, err := c.Do(req) if err != nil { t.Fatal(err) } buf := make([]byte, len(rgz)) if n, err := io.ReadFull(res.Body, buf); err != nil { t.Errorf("%d. ReadFull = %v, %v", i, n, err) } // Note: no res.Body.Close call. It should work without it, // since the flate.Reader's internal buffering will hit EOF // and that should be sufficient. 
} a1, a2 := <-addr, <-addr if a1 != a2 { t.Fatalf("didn't reuse connection") } } func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) } func testTransportResponseHeaderLength(t *testing.T, mode testMode) { if mode == http2Mode { t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes") } ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/long" { w.Header().Set("Long", strings.Repeat("a", 1<<20)) } })).ts c := ts.Client() c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 if res, err := c.Get(ts.URL); err != nil { t.Fatal(err) } else { res.Body.Close() } res, err := c.Get(ts.URL + "/long") if err == nil { defer res.Body.Close() var n int64 for k, vv := range res.Header { for _, v := range vv { n += int64(len(k)) + int64(len(v)) } } t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n) } if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { t.Errorf("got error: %v; want %q", err, want) } } func TestTransportEventTrace(t *testing.T) { run(t, func(t *testing.T, mode testMode) { testTransportEventTrace(t, mode, false) }, testNotParallel) } // test a non-nil httptrace.ClientTrace but with all hooks set to zero. func TestTransportEventTrace_NoHooks(t *testing.T) { run(t, func(t *testing.T, mode testMode) { testTransportEventTrace(t, mode, true) }, testNotParallel) } func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) { const resBody = "some body" gotWroteReqEvent := make(chan struct{}, 500) cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { // Do nothing for the second request. 
return } if _, err := io.ReadAll(r.Body); err != nil { t.Error(err) } if !noHooks { <-gotWroteReqEvent } io.WriteString(w, resBody) }), func(tr *Transport) { if tr.TLSClientConfig != nil { tr.TLSClientConfig.InsecureSkipVerify = true } }) defer cst.close() cst.tr.ExpectContinueTimeout = 1 * time.Second var mu sync.Mutex // guards buf var buf strings.Builder logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) buf.WriteByte('\n') } addrStr := cst.ts.Listener.Addr().String() ip, port, err := net.SplitHostPort(addrStr) if err != nil { t.Fatal(err) } // Install a fake DNS server. ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { if host != "dns-is-faked.golang" { t.Errorf("unexpected DNS host lookup for %q/%q", network, host) return nil, nil } return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil }) body := "some body" req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body)) req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"} trace := &httptrace.ClientTrace{ GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, GotFirstResponseByte: func() { logf("first response byte") }, PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) }, DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) }, DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) }, ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) }, ConnectDone: func(network, addr string, err error) { if err != nil { t.Errorf("ConnectDone: %v", err) } logf("ConnectDone: connected to %s %s = %v", network, addr, err) }, WroteHeaderField: func(key string, value []string) { logf("WroteHeaderField: %s: %v", key, value) }, WroteHeaders: func() { 
logf("WroteHeaders") }, Wait100Continue: func() { logf("Wait100Continue") }, Got100Continue: func() { logf("Got100Continue") }, WroteRequest: func(e httptrace.WroteRequestInfo) { logf("WroteRequest: %+v", e) gotWroteReqEvent <- struct{}{} }, } if mode == http2Mode { trace.TLSHandshakeStart = func() { logf("tls handshake start") } trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) } } if noHooks { // zero out all func pointers, trying to get some path to crash *trace = httptrace.ClientTrace{} } req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) req.Header.Set("Expect", "100-continue") res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } logf("got roundtrip.response") slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } logf("consumed body") if string(slurp) != resBody || res.StatusCode != 200 { t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody) } res.Body.Close() if noHooks { // Done at this point. Just testing a full HTTP // requests can happen with a trace pointing to a zero // ClientTrace, full of nil func pointers. 
return } mu.Lock() got := buf.String() mu.Unlock() wantOnce := func(sub string) { if strings.Count(got, sub) != 1 { t.Errorf("expected substring %q exactly once in output.", sub) } } wantOnceOrMore := func(sub string) { if strings.Count(got, sub) == 0 { t.Errorf("expected substring %q at least once in output.", sub) } } wantOnce("Getting conn for dns-is-faked.golang:" + port) wantOnce("DNS start: {Host:dns-is-faked.golang}") wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err: Coalesced:false}") wantOnce("got conn: {") wantOnceOrMore("Connecting to tcp " + addrStr) wantOnceOrMore("connected to tcp " + addrStr + " = ") wantOnce("Reused:false WasIdle:false IdleTime:0s") wantOnce("first response byte") if mode == http2Mode { wantOnce("tls handshake start") wantOnce("tls handshake done") } else { wantOnce("PutIdleConn = ") wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]") // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the // WroteHeaderField hook is not yet implemented in h2.) 
wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port)) wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body))) wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]") wantOnce("WroteHeaderField: Accept-Encoding: [gzip]") } wantOnce("WroteHeaders") wantOnce("Wait100Continue") wantOnce("Got100Continue") wantOnce("WroteRequest: {Err:}") if strings.Contains(got, " to udp ") { t.Errorf("should not see UDP (DNS) connections") } if t.Failed() { t.Errorf("Output:\n%s", got) } // And do a second request: req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil) req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) res, err = cst.c.Do(req) if err != nil { t.Fatal(err) } if res.StatusCode != 200 { t.Fatal(res.Status) } res.Body.Close() mu.Lock() got = buf.String() mu.Unlock() sub := "Getting conn for dns-is-faked.golang:" if gotn, want := strings.Count(got, sub), 2; gotn != want { t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got) } } func TestTransportEventTraceTLSVerify(t *testing.T) { run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode}) } func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) { var mu sync.Mutex var buf strings.Builder logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) 
buf.WriteByte('\n') } ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("Unexpected request") }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { logf("%s", p) return len(p), nil }), "", 0) }).ts certpool := x509.NewCertPool() certpool.AddCert(ts.Certificate()) c := &Client{Transport: &Transport{ TLSClientConfig: &tls.Config{ ServerName: "dns-is-faked.golang", RootCAs: certpool, }, }} trace := &httptrace.ClientTrace{ TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, TLSHandshakeDone: func(s tls.ConnectionState, err error) { logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err) }, } req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) _, err := c.Do(req) if err == nil { t.Error("Expected request to fail TLS verification") } mu.Lock() got := buf.String() mu.Unlock() wantOnce := func(sub string) { if strings.Count(got, sub) != 1 { t.Errorf("expected substring %q exactly once in output.", sub) } } wantOnce("TLSHandshakeStart") wantOnce("TLSHandshakeDone") wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com") if t.Failed() { t.Errorf("Output:\n%s", got) } } var ( isDNSHijackedOnce sync.Once isDNSHijacked bool ) func skipIfDNSHijacked(t *testing.T) { // Skip this test if the user is using a shady/ISP // DNS server hijacking queries. // See issues 16732, 16716. 
isDNSHijackedOnce.Do(func() { addrs, _ := net.LookupHost("dns-should-not-resolve.golang") isDNSHijacked = len(addrs) != 0 }) if isDNSHijacked { t.Skip("skipping; test requires non-hijacking DNS server") } } func TestTransportEventTraceRealDNS(t *testing.T) { skipIfDNSHijacked(t) defer afterTest(t) tr := &Transport{} defer tr.CloseIdleConnections() c := &Client{Transport: tr} var mu sync.Mutex // guards buf var buf strings.Builder logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) buf.WriteByte('\n') } req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) trace := &httptrace.ClientTrace{ DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, } req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) resp, err := c.Do(req) if err == nil { resp.Body.Close() t.Fatal("expected error during DNS lookup") } mu.Lock() got := buf.String() mu.Unlock() wantSub := func(sub string) { if !strings.Contains(got, sub) { t.Errorf("expected substring %q in output.", sub) } } wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") wantSub("DNSDone: {Addrs:[] Err:") if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { t.Errorf("should not see Connect events") } if t.Failed() { t.Errorf("Output:\n%s", got) } } // Issue 14353: port can only contain digits. 
func TestTransportRejectsAlphaPort(t *testing.T) { res, err := Get("http://dummy.tld:123foo/bar") if err == nil { res.Body.Close() t.Fatal("unexpected success") } ue, ok := err.(*url.Error) if !ok { t.Fatalf("got %#v; want *url.Error", err) } got := ue.Err.Error() want := `invalid port ":123foo" after host` if got != want { t.Errorf("got error %q; want %q", got, want) } } // Test the httptrace.TLSHandshake{Start,Done} hooks with an https http1 // connections. The http2 test is done in TestTransportEventTrace_h2 func TestTLSHandshakeTrace(t *testing.T) { run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode}) } func testTLSHandshakeTrace(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts var mu sync.Mutex var start, done bool trace := &httptrace.ClientTrace{ TLSHandshakeStart: func() { mu.Lock() defer mu.Unlock() start = true }, TLSHandshakeDone: func(s tls.ConnectionState, err error) { mu.Lock() defer mu.Unlock() done = true if err != nil { t.Fatal("Expected error to be nil but was:", err) } }, } c := ts.Client() req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal("Unable to construct test request:", err) } req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) r, err := c.Do(req) if err != nil { t.Fatal("Unexpected error making request:", err) } r.Body.Close() mu.Lock() defer mu.Unlock() if !start { t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") } if !done { t.Fatal("Expected TLSHandshakeDone to be called, but wasn't") } } func TestTransportMaxIdleConns(t *testing.T) { run(t, testTransportMaxIdleConns, []testMode{http1Mode}) } func testTransportMaxIdleConns(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. 
})).ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxIdleConns = 4 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) { return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil }) hitHost := func(n int) { req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil) req = req.WithContext(ctx) res, err := c.Do(req) if err != nil { t.Fatal(err) } res.Body.Close() } for i := 0; i < 4; i++ { hitHost(i) } want := []string{ "|http|host-0.dns-is-faked.golang:" + port, "|http|host-1.dns-is-faked.golang:" + port, "|http|host-2.dns-is-faked.golang:" + port, "|http|host-3.dns-is-faked.golang:" + port, } if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want) } // Now hitting the 5th host should kick out the first host: hitHost(4) want = []string{ "|http|host-1.dns-is-faked.golang:" + port, "|http|host-2.dns-is-faked.golang:" + port, "|http|host-3.dns-is-faked.golang:" + port, "|http|host-4.dns-is-faked.golang:" + port, } if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want) } } func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) } func testTransportIdleConnTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } timeout := 1 * time.Millisecond timeoutLoop: for { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. 
})) tr := cst.tr tr.IdleConnTimeout = timeout defer tr.CloseIdleConnections() c := &Client{Transport: tr} idleConns := func() []string { if mode == http2Mode { return tr.IdleConnStrsForTesting_h2() } else { return tr.IdleConnStrsForTesting() } } var conn string doReq := func(n int) (timeoutOk bool) { req, _ := NewRequest("GET", cst.ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ PutIdleConn: func(err error) { if err != nil { t.Errorf("failed to keep idle conn: %v", err) } }, })) res, err := c.Do(req) if err != nil { if strings.Contains(err.Error(), "use of closed network connection") { t.Logf("req %v: connection closed prematurely", n) return false } } res.Body.Close() conns := idleConns() if len(conns) != 1 { if len(conns) == 0 { t.Logf("req %v: no idle conns", n) return false } t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) } if conn == "" { conn = conns[0] } if conn != conns[0] { t.Logf("req %v: cached connection changed; expected the same one throughout the test", n) return false } return true } for i := 0; i < 3; i++ { if !doReq(i) { t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout) timeout *= 2 cst.close() continue timeoutLoop } time.Sleep(timeout / 2) } waitCondition(t, timeout/2, func(d time.Duration) bool { if got := idleConns(); len(got) != 0 { if d >= timeout*3/2 { t.Logf("after %v, idle conns = %q", d, got) } return false } return true }) break } } // Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an // HTTP/2 connection was established but its caller no longer // wanted it. (Assuming the connection cache was enabled, which it is // by default) // // This test reproduced the crash by setting the IdleConnTimeout low // (to make the test reasonable) and then making a request which is // canceled by the DialTLS hook, which then also waits to return the // real connection until after the RoundTrip saw the error. 
// Then we know the successful tls.Dial from DialTLS will need to go into the
// idle pool. Then we give it a bit of time to explode.
func TestIdleConnH2Crash(t *testing.T) {
	run(t, testIdleConnH2Crash, []testMode{http2Mode})
}

// testIdleConnH2Crash makes a request whose context is canceled from inside
// DialTLS, and DialTLS hands back its (successful) h2 connection only after
// the client has observed the Do error. It then sleeps for 10x the idle
// timeout so any crash in the idle-connection path has time to surface.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// nothing
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is signaled once the client's Do has returned its error.
	sawDoErr := make(chan bool, 1)
	// testDone unblocks DialTLS if the test exits before signaling sawDoErr.
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request while the dial is still in progress, so
		// RoundTrip gives up on this connection.
		cancel()

		// Hold the connection until the caller has seen its error (or the
		// test is over), then return it anyway.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Wait for the explosion.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}

// funcConn is a net.Conn whose Read and Write delegate to the read and write
// funcs and whose Close is a no-op. All other net.Conn methods come from the
// embedded Conn; if that is left nil, calling them panics.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

func (c funcConn) Read(p []byte) (int, error)  { return c.read(p) }
func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
func (c funcConn) Close() error                { return nil }

// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
// back to the caller.
func TestTransportReturnsPeekError(t *testing.T) { errValue := errors.New("specific error value") wrote := make(chan struct{}) var wroteOnce sync.Once tr := &Transport{ Dial: func(network, addr string) (net.Conn, error) { c := funcConn{ read: func([]byte) (int, error) { <-wrote return 0, errValue }, write: func(p []byte) (int, error) { wroteOnce.Do(func() { close(wrote) }) return len(p), nil }, } return c, nil }, } _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) if err != errValue { t.Errorf("error = %#v; want %v", err, errValue) } } // Issue 13835: international domain names should work func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) } func testTransportIDNA(t *testing.T, mode testMode) { const uniDomain = "гофер.го" const punyDomain = "xn--c1ae0ajs.xn--c1aw" var port string cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { want := punyDomain + ":" + port if r.Host != want { t.Errorf("Host header = %q; want %q", r.Host, want) } if mode == http2Mode { if r.TLS == nil { t.Errorf("r.TLS == nil") } else if r.TLS.ServerName != punyDomain { t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain) } } w.Header().Set("Hit-Handler", "1") }), func(tr *Transport) { if tr.TLSClientConfig != nil { tr.TLSClientConfig.InsecureSkipVerify = true } }) ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } // Install a fake DNS server. 
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { if host != punyDomain { t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain) return nil, nil } return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil }) req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil) trace := &httptrace.ClientTrace{ GetConn: func(hostPort string) { want := net.JoinHostPort(punyDomain, port) if hostPort != want { t.Errorf("getting conn for %q; want %q", hostPort, want) } }, DNSStart: func(e httptrace.DNSStartInfo) { if e.Host != punyDomain { t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain) } }, } req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) res, err := cst.tr.RoundTrip(req) if err != nil { t.Fatal(err) } defer res.Body.Close() if res.Header.Get("Hit-Handler") != "1" { out, err := httputil.DumpResponse(res, true) if err != nil { t.Fatal(err) } t.Errorf("Response body wasn't from Handler. 
Got:\n%s\n", out) } } // Issue 13290: send User-Agent in proxy CONNECT func TestTransportProxyConnectHeader(t *testing.T) { run(t, testTransportProxyConnectHeader, []testMode{http1Mode}) } func testTransportProxyConnectHeader(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } reqc <- r c, _, err := w.(Hijacker).Hijack() if err != nil { t.Errorf("Hijack: %v", err) return } c.Close() })).ts c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { return url.Parse(ts.URL) } c.Transport.(*Transport).ProxyConnectHeader = Header{ "User-Agent": {"foo"}, "Other": {"bar"}, } res, err := c.Get("https://dummy.tld/") // https to force a CONNECT if err == nil { res.Body.Close() t.Errorf("unexpected success") } r := <-reqc if got, want := r.Header.Get("User-Agent"), "foo"; got != want { t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) } if got, want := r.Header.Get("Other"), "bar"; got != want { t.Errorf("CONNECT request Other = %q; want %q", got, want) } } func TestTransportProxyGetConnectHeader(t *testing.T) { run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode}) } func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } reqc <- r c, _, err := w.(Hijacker).Hijack() if err != nil { t.Errorf("Hijack: %v", err) return } c.Close() })).ts c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { return url.Parse(ts.URL) } // These should be ignored: c.Transport.(*Transport).ProxyConnectHeader = Header{ "User-Agent": {"foo"}, "Other": {"bar"}, } c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL 
*url.URL, target string) (Header, error) { return Header{ "User-Agent": {"foo2"}, "Other": {"bar2"}, }, nil } res, err := c.Get("https://dummy.tld/") // https to force a CONNECT if err == nil { res.Body.Close() t.Errorf("unexpected success") } r := <-reqc if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) } if got, want := r.Header.Get("Other"), "bar2"; got != want { t.Errorf("CONNECT request Other = %q; want %q", got, want) } } var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) { fn() return nil, errFakeRoundTrip } func wantBody(res *Response, err error, want string) error { if err != nil { return err } slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("error reading body: %v", err) } if string(slurp) != want { return fmt.Errorf("body = %q; want %q", slurp, want) } if err := res.Body.Close(); err != nil { return fmt.Errorf("body Close = %v", err) } return nil } func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { ln, err = net.Listen("tcp6", "[::1]:0") } if err != nil { t.Fatal(err) } return ln } type countCloseReader struct { n *int io.Reader } func (cr countCloseReader) Close() error { (*cr.n)++ return nil } // rgz is a gzip quine that uncompresses to itself. 
var rgz = []byte{ 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 0x00, 0x00, } // Ensure that a missing status doesn't make the server panic // See Issue https://golang.org/issues/21701 func TestMissingStatusNoPanic(t *testing.T) { t.Parallel() const want = "unknown status code" ln := newLocalListener(t) addr := ln.Addr().String() done := make(chan bool) fullAddrURL := fmt.Sprintf("http://%s", addr) raw := "HTTP/1.1 400\r\n" + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + "Content-Type: text/html; charset=utf-8\r\n" + "Content-Length: 10\r\n" + 
"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + "Vary: Accept-Encoding\r\n\r\n" + "Aloha Olaa" go func() { defer close(done) conn, _ := ln.Accept() if conn != nil { io.WriteString(conn, raw) io.ReadAll(conn) conn.Close() } }() proxyURL, err := url.Parse(fullAddrURL) if err != nil { t.Fatalf("proxyURL: %v", err) } tr := &Transport{Proxy: ProxyURL(proxyURL)} req, _ := NewRequest("GET", "https://golang.org/", nil) res, err, panicked := doFetchCheckPanic(tr, req) if panicked { t.Error("panicked, expecting an error") } if res != nil && res.Body != nil { io.Copy(io.Discard, res.Body) res.Body.Close() } if err == nil || !strings.Contains(err.Error(), want) { t.Errorf("got=%v want=%q", err, want) } ln.Close() <-done } func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) { defer func() { if r := recover(); r != nil { panicked = true } }() res, err = tr.RoundTrip(req) return } // Issue 22330: do not allow the response body to be read when the status code // forbids a response body. func TestNoBodyOnChunked304Response(t *testing.T) { run(t, testNoBodyOnChunked304Response, []testMode{http1Mode}) } func testNoBodyOnChunked304Response(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) buf.Flush() conn.Close() })) // Our test server above is sending back bogus data after the // response (the "0\r\n\r\n" part), which causes the Transport // code to log spam. Disable keep-alives so we never even try // to reuse the connection. 
cst.tr.DisableKeepAlives = true res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } if res.Body != NoBody { t.Errorf("Unexpected body on 304 response") } } type funcWriter func([]byte) (int, error) func (f funcWriter) Write(p []byte) (int, error) { return f(p) } type doneContext struct { context.Context err error } func (doneContext) Done() <-chan struct{} { c := make(chan struct{}) close(c) return c } func (d doneContext) Err() error { return d.err } // Issue 25852: Transport should check whether Context is done early. func TestTransportCheckContextDoneEarly(t *testing.T) { tr := &Transport{} req, _ := NewRequest("GET", "http://fake.example/", nil) wantErr := errors.New("some error") req = req.WithContext(doneContext{context.Background(), wantErr}) _, err := tr.RoundTrip(req) if err != wantErr { t.Errorf("error = %v; want %v", err, wantErr) } } // Issue 23399: verify that if a client request times out, the Transport's // conn is closed so that it's not reused. // // This is the test variant that times out before the server replies with // any response headers. func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode}) } func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) { timeout := 1 * time.Millisecond for { inHandler := make(chan bool) cancelHandler := make(chan struct{}) handlerDone := make(chan bool) cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-r.Context().Done() select { case <-cancelHandler: return case inHandler <- true: } defer func() { handlerDone <- true }() // Read from the conn until EOF to verify that it was correctly closed. 
conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } n, err := conn.Read([]byte{0}) if n != 0 || err != io.EOF { t.Errorf("unexpected Read result: %v, %v", n, err) } conn.Close() })) cst.c.Timeout = timeout _, err := cst.c.Get(cst.ts.URL) if err == nil { close(cancelHandler) t.Fatal("unexpected Get success") } tooSlow := time.NewTimer(timeout * 10) select { case <-tooSlow.C: // If we didn't get into the Handler, that probably means the builder was // just slow and the Get failed in that time but never made it to the // server. That's fine; we'll try again with a longer timeout. t.Logf("no handler seen in %v; retrying with longer timeout", timeout) close(cancelHandler) cst.close() timeout *= 2 continue case <-inHandler: tooSlow.Stop() <-handlerDone } break } } // Issue 23399: verify that if a client request times out, the Transport's // conn is closed so that it's not reused. // // This is the test variant that has the server send response headers // first, and time out during the write of the response body. func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode}) } func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) { inHandler := make(chan bool) cancelHandler := make(chan struct{}) handlerDone := make(chan bool) cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "100") w.(Flusher).Flush() select { case <-cancelHandler: return case inHandler <- true: } defer func() { handlerDone <- true }() conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } conn.Write([]byte("foo")) n, err := conn.Read([]byte{0}) // The error should be io.EOF or "read tcp // 127.0.0.1:35827->127.0.0.1:40290: read: connection // reset by peer" depending on timing. Really we just // care that it returns at all. But if it returns with // data, that's weird. 
if n != 0 || err == nil { t.Errorf("unexpected Read result: %v, %v", n, err) } conn.Close() })) // Set Timeout to something very long but non-zero to exercise // the codepaths that check for it. But rather than wait for it to fire // (which would make the test slow), we send on the req.Cancel channel instead, // which happens to exercise the same code paths. cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it. req, _ := NewRequest("GET", cst.ts.URL, nil) cancelReq := make(chan struct{}) req.Cancel = cancelReq res, err := cst.c.Do(req) if err != nil { close(cancelHandler) t.Fatalf("Get error: %v", err) } // Cancel the request while the handler is still blocked on sending to the // inHandler channel. Then read it until it fails, to verify that the // connection is broken before the handler itself closes it. close(cancelReq) got, err := io.ReadAll(res.Body) if err == nil { t.Errorf("unexpected success; read %q, nil", got) } // Now unblock the handler and wait for it to complete. 
<-inHandler <-handlerDone } func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode}) } func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) { done := make(chan struct{}) defer close(done) cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } defer conn.Close() io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") bs := bufio.NewScanner(conn) bs.Scan() fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) <-done })) req, _ := NewRequest("GET", cst.ts.URL, nil) req.Header.Set("Upgrade", "foo") req.Header.Set("Connection", "upgrade") res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } if res.StatusCode != 101 { t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) } rwc, ok := res.Body.(io.ReadWriteCloser) if !ok { t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) } defer rwc.Close() bs := bufio.NewScanner(rwc) if !bs.Scan() { t.Fatalf("expected readable input") } if got, want := bs.Text(), "Some buffered data"; got != want { t.Errorf("read %q; want %q", got, want) } io.WriteString(rwc, "echo\n") if !bs.Scan() { t.Fatalf("expected another line") } if got, want := bs.Text(), "ECHO"; got != want { t.Errorf("read %q; want %q", got, want) } } func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) } func testTransportCONNECTBidi(t *testing.T, mode testMode) { const target = "backend:443" cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("unexpected method %q", r.Method) w.WriteHeader(500) return } if r.RequestURI != target { t.Errorf("unexpected CONNECT target %q", r.RequestURI) w.WriteHeader(500) return } nc, brw, 
err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } defer nc.Close() nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) // Switch to a little protocol that capitalize its input lines: for { line, err := brw.ReadString('\n') if err != nil { if err != io.EOF { t.Error(err) } return } io.WriteString(brw, strings.ToUpper(line)) brw.Flush() } })) pr, pw := io.Pipe() defer pw.Close() req, err := NewRequest("CONNECT", cst.ts.URL, pr) if err != nil { t.Fatal(err) } req.URL.Opaque = target res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() if res.StatusCode != 200 { t.Fatalf("status code = %d; want 200", res.StatusCode) } br := bufio.NewReader(res.Body) for _, str := range []string{"foo", "bar", "baz"} { fmt.Fprintf(pw, "%s\n", str) got, err := br.ReadString('\n') if err != nil { t.Fatal(err) } got = strings.TrimSpace(got) want := strings.ToUpper(str) if got != want { t.Fatalf("got %q; want %q", got, want) } } } func TestTransportRequestReplayable(t *testing.T) { someBody := io.NopCloser(strings.NewReader("")) tests := []struct { name string req *Request want bool }{ { name: "GET", req: &Request{Method: "GET"}, want: true, }, { name: "GET_http.NoBody", req: &Request{Method: "GET", Body: NoBody}, want: true, }, { name: "GET_body", req: &Request{Method: "GET", Body: someBody}, want: false, }, { name: "POST", req: &Request{Method: "POST"}, want: false, }, { name: "POST_idempotency-key", req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}}, want: true, }, { name: "POST_x-idempotency-key", req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}}, want: true, }, { name: "POST_body", req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody}, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.req.ExportIsReplayable() if got != tt.want { t.Errorf("replyable = %v; want %v", got, tt.want) } }) } } // testMockTCPConn is a mock TCP connection 
used to test that // ReadFrom is called when sending the request body. type testMockTCPConn struct { *net.TCPConn ReadFromCalled bool } func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { c.ReadFromCalled = true return c.TCPConn.ReadFrom(r) } func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) } func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) { nBytes := int64(1 << 10) newFileFunc := func() (r io.Reader, done func(), err error) { f, err := os.CreateTemp("", "net-http-newfilefunc") if err != nil { return nil, nil, err } // Write some bytes to the file to enable reading. if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { return nil, nil, fmt.Errorf("failed to write data to file: %v", err) } if _, err := f.Seek(0, 0); err != nil { return nil, nil, fmt.Errorf("failed to seek to front: %v", err) } done = func() { f.Close() os.Remove(f.Name()) } return f, done, nil } newBufferFunc := func() (io.Reader, func(), error) { return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil } cases := []struct { name string readerFunc func() (io.Reader, func(), error) contentLength int64 expectedReadFrom bool }{ { name: "file, length", readerFunc: newFileFunc, contentLength: nBytes, expectedReadFrom: true, }, { name: "file, no length", readerFunc: newFileFunc, }, { name: "file, negative length", readerFunc: newFileFunc, contentLength: -1, }, { name: "buffer", contentLength: nBytes, readerFunc: newBufferFunc, }, { name: "buffer, no length", readerFunc: newBufferFunc, }, { name: "buffer, length -1", contentLength: -1, readerFunc: newBufferFunc, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r, cleanup, err := tc.readerFunc() if err != nil { t.Fatal(err) } defer cleanup() tConn := &testMockTCPConn{} trFunc := func(tr *Transport) { tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { var d net.Dialer conn, err := d.DialContext(ctx, network, addr) 
if err != nil { return nil, err } tcpConn, ok := conn.(*net.TCPConn) if !ok { return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) } tConn.TCPConn = tcpConn return tConn, nil } } cst := newClientServerTest( t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) r.Body.Close() w.WriteHeader(200) }), trFunc, ) req, err := NewRequest("PUT", cst.ts.URL, r) if err != nil { t.Fatal(err) } req.ContentLength = tc.contentLength req.Header.Set("Content-Type", "application/octet-stream") resp, err := cst.c.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("status code = %d; want 200", resp.StatusCode) } expectedReadFrom := tc.expectedReadFrom if mode != http1Mode { expectedReadFrom = false } if !tConn.ReadFromCalled && expectedReadFrom { t.Fatalf("did not call ReadFrom") } if tConn.ReadFromCalled && !expectedReadFrom { t.Fatalf("ReadFrom was unexpectedly invoked") } }) } } func TestTransportClone(t *testing.T) { tr := &Transport{ Proxy: func(*Request) (*url.URL, error) { panic("") }, OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { return nil }, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, Dial: func(network, addr string) (net.Conn, error) { panic("") }, DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, TLSClientConfig: new(tls.Config), TLSHandshakeTimeout: time.Second, DisableKeepAlives: true, DisableCompression: true, MaxIdleConns: 1, MaxIdleConnsPerHost: 1, MaxConnsPerHost: 1, IdleConnTimeout: time.Second, ResponseHeaderTimeout: time.Second, ExpectContinueTimeout: time.Second, ProxyConnectHeader: Header{}, GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil }, MaxResponseHeaderBytes: 1, 
ForceAttemptHTTP2: true, TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") }, }, ReadBufferSize: 1, WriteBufferSize: 1, } tr2 := tr.Clone() rv := reflect.ValueOf(tr2).Elem() rt := rv.Type() for i := 0; i < rt.NumField(); i++ { sf := rt.Field(i) if !token.IsExported(sf.Name) { continue } if rv.Field(i).IsZero() { t.Errorf("cloned field t2.%s is zero", sf.Name) } } if _, ok := tr2.TLSNextProto["foo"]; !ok { t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") } // But test that a nil TLSNextProto is kept nil: tr = new(Transport) tr2 = tr.Clone() if tr2.TLSNextProto != nil { t.Errorf("Transport.TLSNextProto unexpected non-nil") } } func TestIs408(t *testing.T) { tests := []struct { in string want bool }{ {"HTTP/1.0 408", true}, {"HTTP/1.1 408", true}, {"HTTP/1.8 408", true}, {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. {"HTTP/1.1 408 ", true}, {"HTTP/1.1 40", false}, {"http/1.0 408", false}, {"HTTP/1-1 408", false}, } for _, tt := range tests { if got := Export_is408Message([]byte(tt.in)); got != tt.want { t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) } } } func TestTransportIgnores408(t *testing.T) { run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel) } func testTransportIgnores408(t *testing.T, mode testMode) { // Not parallel. Relies on mutating the log package's global Output. 
defer log.SetOutput(log.Writer()) var logout strings.Builder log.SetOutput(&logout) const target = "backend:443" cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { nc, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } defer nc.Close() nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail })) req, err := NewRequest("GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) } res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if err != nil { t.Fatal(err) } if string(slurp) != "ok" { t.Fatalf("got %q; want ok", slurp) } waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool { if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 { if d > 0 { t.Logf("%v idle conns still present after %v", n, d) } return false } return true }) if got := logout.String(); got != "" { t.Fatalf("expected no log output; got: %s", got) } } func TestInvalidHeaderResponse(t *testing.T) { run(t, testInvalidHeaderResponse, []testMode{http1Mode}) } func testInvalidHeaderResponse(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 200 OK\r\n" + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + "Content-Type: text/html; charset=utf-8\r\n" + "Content-Length: 0\r\n" + "Foo : bar\r\n\r\n")) buf.Flush() conn.Close() })) res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } defer res.Body.Close() if v := res.Header.Get("Foo"); v != "" { t.Errorf(`unexpected "Foo" header: %q`, v) } if v := res.Header.Get("Foo "); v != "bar" { t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar") } } type bodyCloser bool func (bc *bodyCloser) Close() error { *bc = true return nil } func (bc *bodyCloser) Read(b []byte) (n int, err error) { return 0, io.EOF } // Issue 
35015: ensure that Transport closes the body on any error // with an invalid request, as promised by Client.Do docs. func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { run(t, testTransportClosesBodyOnInvalidRequests) } func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Errorf("Should not have been invoked") })).ts u, _ := url.Parse(cst.URL) tests := []struct { name string req *Request wantErr string }{ { name: "invalid method", req: &Request{ Method: " ", URL: u, }, wantErr: `invalid method " "`, }, { name: "nil URL", req: &Request{ Method: "GET", }, wantErr: `nil Request.URL`, }, { name: "invalid header key", req: &Request{ Method: "GET", Header: Header{"💡": {"emoji"}}, URL: u, }, wantErr: `invalid header field name "💡"`, }, { name: "invalid header value", req: &Request{ Method: "POST", Header: Header{"key": {"\x19"}}, URL: u, }, wantErr: `invalid header field value for "key"`, }, { name: "non HTTP(s) scheme", req: &Request{ Method: "POST", URL: &url.URL{Scheme: "faux"}, }, wantErr: `unsupported protocol scheme "faux"`, }, { name: "no Host in URL", req: &Request{ Method: "POST", URL: &url.URL{Scheme: "http"}, }, wantErr: `no Host in request URL`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var bc bodyCloser req := tt.req req.Body = &bc _, err := cst.Client().Do(tt.req) if err == nil { t.Fatal("Expected an error") } if !bc { t.Fatal("Expected body to have been closed") } if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) { t.Fatalf("Error mismatch: %q does not end with %q", g, w) } }) } } // breakableConn is a net.Conn wrapper with a Write method // that will fail when its brokenState is true. 
type breakableConn struct { net.Conn *brokenState } type brokenState struct { sync.Mutex broken bool } func (w *breakableConn) Write(b []byte) (n int, err error) { w.Lock() defer w.Unlock() if w.broken { return 0, errors.New("some write error") } return w.Conn.Write(b) } // Issue 34978: don't cache a broken HTTP/2 connection func TestDontCacheBrokenHTTP2Conn(t *testing.T) { run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode}) } func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) var brokenState brokenState const numReqs = 5 var numDials, gotConns uint32 // atomic cst.tr.Dial = func(netw, addr string) (net.Conn, error) { atomic.AddUint32(&numDials, 1) c, err := net.Dial(netw, addr) if err != nil { t.Errorf("unexpected Dial error: %v", err) return nil, err } return &breakableConn{c, &brokenState}, err } for i := 1; i <= numReqs; i++ { brokenState.Lock() brokenState.broken = false brokenState.Unlock() // doBreak controls whether we break the TCP connection after the TLS // handshake (before the HTTP/2 handshake). We test a few failures // in a row followed by a final success. 
doBreak := i != numReqs ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime) atomic.AddUint32(&gotConns, 1) }, TLSHandshakeDone: func(cfg tls.ConnectionState, err error) { brokenState.Lock() defer brokenState.Unlock() if doBreak { brokenState.broken = true } }, }) req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) } _, err = cst.c.Do(req) if doBreak != (err != nil) { t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err) } } if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want { t.Errorf("GotConn calls = %v; want %v", got, want) } if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want { t.Errorf("Dials = %v; want %v", got, want) } } // Issue 34941 // When the client has too many concurrent requests on a single connection, // http.http2noCachedConnError is reported on multiple requests. There should // only be one decrement regardless of the number of failures. 
func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode}) } func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } }) ts := newClientServerTest(t, mode, h).ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxConnsPerHost = 1 errCh := make(chan error, 300) doReq := func() { resp, err := c.Get(ts.URL) if err != nil { errCh <- fmt.Errorf("request failed: %v", err) return } defer resp.Body.Close() _, err = io.ReadAll(resp.Body) if err != nil { errCh <- fmt.Errorf("read body failed: %v", err) } } var wg sync.WaitGroup for i := 0; i < 300; i++ { wg.Add(1) go func() { defer wg.Done() doReq() }() } wg.Wait() close(errCh) for err := range errCh { t.Errorf("error occurred: %v", err) } } // Issue 36820 // Test that we use the older backward compatible cancellation protocol // when a RoundTripper is registered via RegisterProtocol. 
func TestAltProtoCancellation(t *testing.T) { defer afterTest(t) tr := &Transport{} c := &Client{ Transport: tr, Timeout: time.Millisecond, } tr.RegisterProtocol("cancel", cancelProto{}) _, err := c.Get("cancel://bar.com/path") if err == nil { t.Error("request unexpectedly succeeded") } else if !strings.Contains(err.Error(), errCancelProto.Error()) { t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto) } } var errCancelProto = errors.New("canceled as expected") type cancelProto struct{} func (cancelProto) RoundTrip(req *Request) (*Response, error) { <-req.Cancel return nil, errCancelProto } type roundTripFunc func(r *Request) (*Response, error) func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } // Issue 32441: body is not reset after ErrSkipAltProtocol func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) } func testIssue32441(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero") } })).ts c := ts.Client() c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { // Draining body to trigger failure condition on actual request to server. if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero during round trip") } return nil, ErrSkipAltProtocol })) if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { t.Error(err) } } // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. 
func TestTransportRejectsSignInContentLength(t *testing.T) { run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode}) } func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "+3") w.Write([]byte("abc")) })).ts c := cst.Client() res, err := c.Get(cst.URL) if err == nil || res != nil { t.Fatal("Expected a non-nil error and a nil http.Response") } if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) { t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) } } // dumpConn is a net.Conn which writes to Writer and reads from Reader type dumpConn struct { io.Writer io.Reader } func (c *dumpConn) Close() error { return nil } func (c *dumpConn) LocalAddr() net.Addr { return nil } func (c *dumpConn) RemoteAddr() net.Addr { return nil } func (c *dumpConn) SetDeadline(t time.Time) error { return nil } func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } // delegateReader is a reader that delegates to another reader, // once it arrives on a channel. 
type delegateReader struct { c chan io.Reader r io.Reader // nil until received from c } func (r *delegateReader) Read(p []byte) (int, error) { if r.r == nil { var ok bool if r.r, ok = <-r.c; !ok { return 0, errors.New("delegate closed") } } return r.r.Read(p) } func testTransportRace(req *Request) { save := req.Body pr, pw := io.Pipe() defer pr.Close() defer pw.Close() dr := &delegateReader{c: make(chan io.Reader)} t := &Transport{ Dial: func(net, addr string) (net.Conn, error) { return &dumpConn{pw, dr}, nil }, } defer t.CloseIdleConnections() quitReadCh := make(chan struct{}) // Wait for the request before replying with a dummy response: go func() { defer close(quitReadCh) req, err := ReadRequest(bufio.NewReader(pr)) if err == nil { // Ensure all the body is read; otherwise // we'll get a partial dump. io.Copy(io.Discard, req.Body) req.Body.Close() } select { case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): case quitReadCh <- struct{}{}: // Ensure delegate is closed so Read doesn't block forever. close(dr.c) } }() t.RoundTrip(req) // Ensure the reader returns before we reset req.Body to prevent // a data race on req.Body. pw.Close() <-quitReadCh req.Body = save } // Issue 37669 // Test that a cancellation doesn't result in a data race due to the writeLoop // goroutine being left running, if the caller mutates the processed Request // upon completion. func TestErrorWriteLoopRace(t *testing.T) { if testing.Short() { return } t.Parallel() for i := 0; i < 1000; i++ { delay := time.Duration(mrand.Intn(5)) * time.Millisecond ctx, cancel := context.WithTimeout(context.Background(), delay) defer cancel() r := bytes.NewBuffer(make([]byte, 10000)) req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r) if err != nil { t.Fatal(err) } testTransportRace(req) } } // Issue 41600 // Test that a new request which uses the connection of an active request // cannot cause it to be canceled as well. 
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}

// testCancelRequestWhenSharingConnection runs two requests on a client that
// is limited to one connection.
//
// The first request completes, but its PutIdleConn trace hook is held
// blocked. A second request is then started and its context is canceled
// once it reaches the server. The second request must fail with
// context.Canceled. If the first request's result is observed before its
// PutIdleConn hook, that request must have succeeded.
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each handler invocation hands the test a channel through reqc and then
	// blocks until the test closes (or sends on) that channel.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	// Restrict the transport to a single connection (and one idle slot).
	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Signal that the idle conn has been returned to the pool,
				// and wait for the order to proceed.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec) // panic if PutIdleConn runs twice for some reason
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		reqerrc <- err
		if err == nil {
			res.Body.Close()
		}
	}()

	// Wait for the first request to receive a response and return the
	// connection to the idle pool.
	r1c := <-reqc
	close(r1c)
	// The PutIdleConn hook and the return from client.Do can be observed in
	// either order, so accept both.
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		if err != nil {
			t.Fatalf("request 1: got err %v, want nil", err)
		}
		idlec = <-putidlec
	case idlec = <-putidlec:
	}

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Unblock the first request.
		close(idlec)
	}()

	// Wait for the second request to arrive at the server, and then cancel
	// the request context.
	r2c := <-reqc
	cancel()

	// idlec is closed by the second goroutine once request 2 has returned.
	<-idlec

	// Let the second handler finish, then wait for both client goroutines.
	close(r2c)
	wg.Wait()
}

func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }

// testHandlerAbortRacesBodyRead has each handler start a goroutine that
// drains the request body and then immediately abort with
// panic(ErrAbortHandler). Two client goroutines each send ten 6 MiB POSTs.
// Round-trip errors are ignored, so the test only fails on crashes or
// reported races (presumably it is meant to be run under -race).
func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		go io.Copy(io.Discard, req.Body)
		panic(ErrAbortHandler)
	})).ts

	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 10; j++ {
				const reqLen = 6 * 1024 * 1024
				req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
				req.ContentLength = reqLen
				resp, _ := ts.Client().Transport.RoundTrip(req)
				if resp != nil {
					resp.Body.Close()
				}
			}
		}()
	}
	wg.Wait()
}

func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }

// testRequestSanitization sets req.Host to a value containing CRLF followed
// by a header line. It checks that the server never sees an X-Evil header.
// The result of Do is not checked; only the server-side check matters.
func testRequestSanitization(t *testing.T, mode testMode) {
	if mode == http2Mode {
		// Remove this after updating x/net.
		t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		if h, ok := req.Header["X-Evil"]; ok {
			t.Errorf("request has X-Evil header: %q", h)
		}
	})).ts
	req, _ := NewRequest("GET", ts.URL, nil)
	req.Host = "go.dev\r\nX-Evil:evil"
	resp, _ := ts.Client().Do(req)
	if resp != nil {
		resp.Body.Close()
	}
}

func TestProxyAuthHeader(t *testing.T) {
	// Not parallel: Sets an environment variable.
	run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
}

// testProxyAuthHeader uses the test server as the client's proxy, with
// credentials embedded in the proxy URL. The server checks that they arrive
// as basic credentials in the Proxy-Authorization header. The password
// contains URL-reserved characters, presumably to exercise userinfo escaping.
func testProxyAuthHeader(t *testing.T, mode testMode) {
	const username = "u"
	const password = "@/?!"
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// Copy the Proxy-Authorization header to a new Request,
		// since Request.BasicAuth only parses the Authorization header.
		var r2 Request
		r2.Header = Header{
			"Authorization": req.Header["Proxy-Authorization"],
		}
		gotuser, gotpass, ok := r2.BasicAuth()
		if !ok || gotuser != username || gotpass != password {
			t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true",
				gotuser, gotpass, ok, username, password)
		}
	}))
	u, err := url.Parse(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	u.User = url.UserPassword(username, password)
	// The proxy is configured both through the environment and directly on
	// the transport.
	t.Setenv("HTTP_PROXY", u.String())
	cst.tr.Proxy = ProxyURL(u)
	resp, err := cst.c.Get("http://_/")
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()
}