Source file src/net/http/clientserver_test.go

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
     6  
     7  package http_test
     8  
     9  import (
    10  	"bytes"
    11  	"compress/gzip"
    12  	"context"
    13  	"crypto/rand"
    14  	"crypto/sha1"
    15  	"crypto/tls"
    16  	"fmt"
    17  	"hash"
    18  	"io"
    19  	"log"
    20  	"net"
    21  	. "net/http"
    22  	"net/http/httptest"
    23  	"net/http/httptrace"
    24  	"net/http/httputil"
    25  	"net/textproto"
    26  	"net/url"
    27  	"os"
    28  	"reflect"
    29  	"runtime"
    30  	"sort"
    31  	"strings"
    32  	"sync"
    33  	"sync/atomic"
    34  	"testing"
    35  	"time"
    36  )
    37  
    38  type clientServerTest struct {
    39  	t  *testing.T
    40  	h2 bool
    41  	h  Handler
    42  	ts *httptest.Server
    43  	tr *Transport
    44  	c  *Client
    45  }
    46  
    47  func (t *clientServerTest) close() {
    48  	t.tr.CloseIdleConnections()
    49  	t.ts.Close()
    50  }
    51  
    52  func (t *clientServerTest) getURL(u string) string {
    53  	res, err := t.c.Get(u)
    54  	if err != nil {
    55  		t.t.Fatal(err)
    56  	}
    57  	defer res.Body.Close()
    58  	slurp, err := io.ReadAll(res.Body)
    59  	if err != nil {
    60  		t.t.Fatal(err)
    61  	}
    62  	return string(slurp)
    63  }
    64  
    65  func (t *clientServerTest) scheme() string {
    66  	if t.h2 {
    67  		return "https"
    68  	}
    69  	return "http"
    70  }
    71  
    72  const (
    73  	h1Mode = false
    74  	h2Mode = true
    75  )
    76  
    77  var optQuietLog = func(ts *httptest.Server) {
    78  	ts.Config.ErrorLog = quietLog
    79  }
    80  
    81  func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
    82  	return func(ts *httptest.Server) {
    83  		ts.Config.ErrorLog = lg
    84  	}
    85  }
    86  
    87  func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest {
    88  	if h2 {
    89  		CondSkipHTTP2(t)
    90  	}
    91  	cst := &clientServerTest{
    92  		t:  t,
    93  		h2: h2,
    94  		h:  h,
    95  		tr: &Transport{},
    96  	}
    97  	cst.c = &Client{Transport: cst.tr}
    98  	cst.ts = httptest.NewUnstartedServer(h)
    99  
   100  	for _, opt := range opts {
   101  		switch opt := opt.(type) {
   102  		case func(*Transport):
   103  			opt(cst.tr)
   104  		case func(*httptest.Server):
   105  			opt(cst.ts)
   106  		default:
   107  			t.Fatalf("unhandled option type %T", opt)
   108  		}
   109  	}
   110  
   111  	if !h2 {
   112  		cst.ts.Start()
   113  		return cst
   114  	}
   115  	ExportHttp2ConfigureServer(cst.ts.Config, nil)
   116  	cst.ts.TLS = cst.ts.Config.TLSConfig
   117  	cst.ts.StartTLS()
   118  
   119  	cst.tr.TLSClientConfig = &tls.Config{
   120  		InsecureSkipVerify: true,
   121  	}
   122  	if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
   123  		t.Fatal(err)
   124  	}
   125  	return cst
   126  }
   127  
   128  // Testing the newClientServerTest helper itself.
   129  func TestNewClientServerTest(t *testing.T) {
   130  	var got struct {
   131  		sync.Mutex
   132  		log []string
   133  	}
   134  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   135  		got.Lock()
   136  		defer got.Unlock()
   137  		got.log = append(got.log, r.Proto)
   138  	})
   139  	for _, v := range [2]bool{false, true} {
   140  		cst := newClientServerTest(t, v, h)
   141  		if _, err := cst.c.Head(cst.ts.URL); err != nil {
   142  			t.Fatal(err)
   143  		}
   144  		cst.close()
   145  	}
   146  	got.Lock() // no need to unlock
   147  	if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) {
   148  		t.Errorf("got %q; want %q", got.log, want)
   149  	}
   150  }
   151  
   152  func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) }
   153  func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) }
   154  
   155  func testChunkedResponseHeaders(t *testing.T, h2 bool) {
   156  	defer afterTest(t)
   157  	log.SetOutput(io.Discard) // is noisy otherwise
   158  	defer log.SetOutput(os.Stderr)
   159  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   160  		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
   161  		w.(Flusher).Flush()
   162  		fmt.Fprintf(w, "I am a chunked response.")
   163  	}))
   164  	defer cst.close()
   165  
   166  	res, err := cst.c.Get(cst.ts.URL)
   167  	if err != nil {
   168  		t.Fatalf("Get error: %v", err)
   169  	}
   170  	defer res.Body.Close()
   171  	if g, e := res.ContentLength, int64(-1); g != e {
   172  		t.Errorf("expected ContentLength of %d; got %d", e, g)
   173  	}
   174  	wantTE := []string{"chunked"}
   175  	if h2 {
   176  		wantTE = nil
   177  	}
   178  	if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
   179  		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
   180  	}
   181  	if got, haveCL := res.Header["Content-Length"]; haveCL {
   182  		t.Errorf("Unexpected Content-Length: %q", got)
   183  	}
   184  }
   185  
   186  type reqFunc func(c *Client, url string) (*Response, error)
   187  
   188  // h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
   189  // against each other.
   190  type h12Compare struct {
   191  	Handler            func(ResponseWriter, *Request)    // required
   192  	ReqFunc            reqFunc                           // optional
   193  	CheckResponse      func(proto string, res *Response) // optional
   194  	EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
   195  	Opts               []any
   196  }
   197  
   198  func (tt h12Compare) reqFunc() reqFunc {
   199  	if tt.ReqFunc == nil {
   200  		return (*Client).Get
   201  	}
   202  	return tt.ReqFunc
   203  }
   204  
   205  func (tt h12Compare) run(t *testing.T) {
   206  	setParallel(t)
   207  	cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
   208  	defer cst1.close()
   209  	cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
   210  	defer cst2.close()
   211  
   212  	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
   213  	if err != nil {
   214  		t.Errorf("HTTP/1 request: %v", err)
   215  		return
   216  	}
   217  	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
   218  	if err != nil {
   219  		t.Errorf("HTTP/2 request: %v", err)
   220  		return
   221  	}
   222  
   223  	if fn := tt.EarlyCheckResponse; fn != nil {
   224  		fn("HTTP/1.1", res1)
   225  		fn("HTTP/2.0", res2)
   226  	}
   227  
   228  	tt.normalizeRes(t, res1, "HTTP/1.1")
   229  	tt.normalizeRes(t, res2, "HTTP/2.0")
   230  	res1body, res2body := res1.Body, res2.Body
   231  
   232  	eres1 := mostlyCopy(res1)
   233  	eres2 := mostlyCopy(res2)
   234  	if !reflect.DeepEqual(eres1, eres2) {
   235  		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
   236  			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
   237  	}
   238  	if !reflect.DeepEqual(res1body, res2body) {
   239  		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
   240  	}
   241  	if fn := tt.CheckResponse; fn != nil {
   242  		res1.Body, res2.Body = res1body, res2body
   243  		fn("HTTP/1.1", res1)
   244  		fn("HTTP/2.0", res2)
   245  	}
   246  }
   247  
   248  func mostlyCopy(r *Response) *Response {
   249  	c := *r
   250  	c.Body = nil
   251  	c.TransferEncoding = nil
   252  	c.TLS = nil
   253  	c.Request = nil
   254  	return &c
   255  }
   256  
   257  type slurpResult struct {
   258  	io.ReadCloser
   259  	body []byte
   260  	err  error
   261  }
   262  
   263  func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
   264  
   265  func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
   266  	if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
   267  		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
   268  	} else {
   269  		t.Errorf("got %q response; want %q", res.Proto, wantProto)
   270  	}
   271  	slurp, err := io.ReadAll(res.Body)
   272  
   273  	res.Body.Close()
   274  	res.Body = slurpResult{
   275  		ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
   276  		body:       slurp,
   277  		err:        err,
   278  	}
   279  	for i, v := range res.Header["Date"] {
   280  		res.Header["Date"][i] = strings.Repeat("x", len(v))
   281  	}
   282  	if res.Request == nil {
   283  		t.Errorf("for %s, no request", wantProto)
   284  	}
   285  	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
   286  		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
   287  	}
   288  }
   289  
   290  // Issue 13532
   291  func TestH12_HeadContentLengthNoBody(t *testing.T) {
   292  	h12Compare{
   293  		ReqFunc: (*Client).Head,
   294  		Handler: func(w ResponseWriter, r *Request) {
   295  		},
   296  	}.run(t)
   297  }
   298  
   299  func TestH12_HeadContentLengthSmallBody(t *testing.T) {
   300  	h12Compare{
   301  		ReqFunc: (*Client).Head,
   302  		Handler: func(w ResponseWriter, r *Request) {
   303  			io.WriteString(w, "small")
   304  		},
   305  	}.run(t)
   306  }
   307  
   308  func TestH12_HeadContentLengthLargeBody(t *testing.T) {
   309  	h12Compare{
   310  		ReqFunc: (*Client).Head,
   311  		Handler: func(w ResponseWriter, r *Request) {
   312  			chunk := strings.Repeat("x", 512<<10)
   313  			for i := 0; i < 10; i++ {
   314  				io.WriteString(w, chunk)
   315  			}
   316  		},
   317  	}.run(t)
   318  }
   319  
   320  func TestH12_200NoBody(t *testing.T) {
   321  	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
   322  }
   323  
   324  func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
   325  func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
   326  func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
   327  
   328  func testH12_noBody(t *testing.T, status int) {
   329  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   330  		w.WriteHeader(status)
   331  	}}.run(t)
   332  }
   333  
   334  func TestH12_SmallBody(t *testing.T) {
   335  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   336  		io.WriteString(w, "small body")
   337  	}}.run(t)
   338  }
   339  
   340  func TestH12_ExplicitContentLength(t *testing.T) {
   341  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   342  		w.Header().Set("Content-Length", "3")
   343  		io.WriteString(w, "foo")
   344  	}}.run(t)
   345  }
   346  
   347  func TestH12_FlushBeforeBody(t *testing.T) {
   348  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   349  		w.(Flusher).Flush()
   350  		io.WriteString(w, "foo")
   351  	}}.run(t)
   352  }
   353  
   354  func TestH12_FlushMidBody(t *testing.T) {
   355  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   356  		io.WriteString(w, "foo")
   357  		w.(Flusher).Flush()
   358  		io.WriteString(w, "bar")
   359  	}}.run(t)
   360  }
   361  
   362  func TestH12_Head_ExplicitLen(t *testing.T) {
   363  	h12Compare{
   364  		ReqFunc: (*Client).Head,
   365  		Handler: func(w ResponseWriter, r *Request) {
   366  			if r.Method != "HEAD" {
   367  				t.Errorf("unexpected method %q", r.Method)
   368  			}
   369  			w.Header().Set("Content-Length", "1235")
   370  		},
   371  	}.run(t)
   372  }
   373  
   374  func TestH12_Head_ImplicitLen(t *testing.T) {
   375  	h12Compare{
   376  		ReqFunc: (*Client).Head,
   377  		Handler: func(w ResponseWriter, r *Request) {
   378  			if r.Method != "HEAD" {
   379  				t.Errorf("unexpected method %q", r.Method)
   380  			}
   381  			io.WriteString(w, "foo")
   382  		},
   383  	}.run(t)
   384  }
   385  
   386  func TestH12_HandlerWritesTooLittle(t *testing.T) {
   387  	h12Compare{
   388  		Handler: func(w ResponseWriter, r *Request) {
   389  			w.Header().Set("Content-Length", "3")
   390  			io.WriteString(w, "12") // one byte short
   391  		},
   392  		CheckResponse: func(proto string, res *Response) {
   393  			sr, ok := res.Body.(slurpResult)
   394  			if !ok {
   395  				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
   396  				return
   397  			}
   398  			if sr.err != io.ErrUnexpectedEOF {
   399  				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
   400  			}
   401  			if string(sr.body) != "12" {
   402  				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
   403  			}
   404  		},
   405  	}.run(t)
   406  }
   407  
   408  // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
   409  // writing more than they declared. This test does not test whether
   410  // the transport deals with too much data, though, since the server
   411  // doesn't make it possible to send bogus data. For those tests, see
   412  // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
   413  // (for HTTP/2).
   414  func TestH12_HandlerWritesTooMuch(t *testing.T) {
   415  	h12Compare{
   416  		Handler: func(w ResponseWriter, r *Request) {
   417  			w.Header().Set("Content-Length", "3")
   418  			w.(Flusher).Flush()
   419  			io.WriteString(w, "123")
   420  			w.(Flusher).Flush()
   421  			n, err := io.WriteString(w, "x") // too many
   422  			if n > 0 || err == nil {
   423  				t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
   424  			}
   425  		},
   426  	}.run(t)
   427  }
   428  
   429  // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
   430  // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
   431  func TestH12_AutoGzip(t *testing.T) {
   432  	h12Compare{
   433  		Handler: func(w ResponseWriter, r *Request) {
   434  			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
   435  				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
   436  			}
   437  			w.Header().Set("Content-Encoding", "gzip")
   438  			gz := gzip.NewWriter(w)
   439  			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
   440  			gz.Close()
   441  		},
   442  	}.run(t)
   443  }
   444  
   445  func TestH12_AutoGzip_Disabled(t *testing.T) {
   446  	h12Compare{
   447  		Opts: []any{
   448  			func(tr *Transport) { tr.DisableCompression = true },
   449  		},
   450  		Handler: func(w ResponseWriter, r *Request) {
   451  			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
   452  			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
   453  				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
   454  			}
   455  		},
   456  	}.run(t)
   457  }
   458  
   459  // Test304Responses verifies that 304s don't declare that they're
   460  // chunking in their response headers and aren't allowed to produce
   461  // output.
   462  func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) }
   463  func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) }
   464  
   465  func test304Responses(t *testing.T, h2 bool) {
   466  	defer afterTest(t)
   467  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   468  		w.WriteHeader(StatusNotModified)
   469  		_, err := w.Write([]byte("illegal body"))
   470  		if err != ErrBodyNotAllowed {
   471  			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
   472  		}
   473  	}))
   474  	defer cst.close()
   475  	res, err := cst.c.Get(cst.ts.URL)
   476  	if err != nil {
   477  		t.Fatal(err)
   478  	}
   479  	if len(res.TransferEncoding) > 0 {
   480  		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
   481  	}
   482  	body, err := io.ReadAll(res.Body)
   483  	if err != nil {
   484  		t.Error(err)
   485  	}
   486  	if len(body) > 0 {
   487  		t.Errorf("got unexpected body %q", string(body))
   488  	}
   489  }
   490  
   491  func TestH12_ServerEmptyContentLength(t *testing.T) {
   492  	h12Compare{
   493  		Handler: func(w ResponseWriter, r *Request) {
   494  			w.Header()["Content-Type"] = []string{""}
   495  			io.WriteString(w, "<html><body>hi</body></html>")
   496  		},
   497  	}.run(t)
   498  }
   499  
   500  func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
   501  	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
   502  }
   503  
   504  func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
   505  	h12requestContentLength(t, func() io.Reader { return nil }, 0)
   506  }
   507  
   508  func TestH12_RequestContentLength_Unknown(t *testing.T) {
   509  	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
   510  }
   511  
   512  func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
   513  	h12Compare{
   514  		Handler: func(w ResponseWriter, r *Request) {
   515  			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
   516  			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
   517  		},
   518  		ReqFunc: func(c *Client, url string) (*Response, error) {
   519  			return c.Post(url, "text/plain", bodyfn())
   520  		},
   521  		CheckResponse: func(proto string, res *Response) {
   522  			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
   523  				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
   524  			}
   525  		},
   526  	}.run(t)
   527  }
   528  
   529  // Tests that closing the Request.Cancel channel also while still
   530  // reading the response body. Issue 13159.
   531  func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) }
   532  func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) }
   533  func testCancelRequestMidBody(t *testing.T, h2 bool) {
   534  	defer afterTest(t)
   535  	unblock := make(chan bool)
   536  	didFlush := make(chan bool, 1)
   537  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   538  		io.WriteString(w, "Hello")
   539  		w.(Flusher).Flush()
   540  		didFlush <- true
   541  		<-unblock
   542  		io.WriteString(w, ", world.")
   543  	}))
   544  	defer cst.close()
   545  	defer close(unblock)
   546  
   547  	req, _ := NewRequest("GET", cst.ts.URL, nil)
   548  	cancel := make(chan struct{})
   549  	req.Cancel = cancel
   550  
   551  	res, err := cst.c.Do(req)
   552  	if err != nil {
   553  		t.Fatal(err)
   554  	}
   555  	defer res.Body.Close()
   556  	<-didFlush
   557  
   558  	// Read a bit before we cancel. (Issue 13626)
   559  	// We should have "Hello" at least sitting there.
   560  	firstRead := make([]byte, 10)
   561  	n, err := res.Body.Read(firstRead)
   562  	if err != nil {
   563  		t.Fatal(err)
   564  	}
   565  	firstRead = firstRead[:n]
   566  
   567  	close(cancel)
   568  
   569  	rest, err := io.ReadAll(res.Body)
   570  	all := string(firstRead) + string(rest)
   571  	if all != "Hello" {
   572  		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
   573  	}
   574  	if err != ExportErrRequestCanceled {
   575  		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
   576  	}
   577  }
   578  
   579  // Tests that clients can send trailers to a server and that the server can read them.
   580  func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) }
   581  func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) }
   582  
   583  func testTrailersClientToServer(t *testing.T, h2 bool) {
   584  	defer afterTest(t)
   585  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   586  		var decl []string
   587  		for k := range r.Trailer {
   588  			decl = append(decl, k)
   589  		}
   590  		sort.Strings(decl)
   591  
   592  		slurp, err := io.ReadAll(r.Body)
   593  		if err != nil {
   594  			t.Errorf("Server reading request body: %v", err)
   595  		}
   596  		if string(slurp) != "foo" {
   597  			t.Errorf("Server read request body %q; want foo", slurp)
   598  		}
   599  		if r.Trailer == nil {
   600  			io.WriteString(w, "nil Trailer")
   601  		} else {
   602  			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
   603  				decl,
   604  				r.Trailer.Get("Client-Trailer-A"),
   605  				r.Trailer.Get("Client-Trailer-B"))
   606  		}
   607  	}))
   608  	defer cst.close()
   609  
   610  	var req *Request
   611  	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
   612  		eofReaderFunc(func() {
   613  			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
   614  		}),
   615  		strings.NewReader("foo"),
   616  		eofReaderFunc(func() {
   617  			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
   618  		}),
   619  	))
   620  	req.Trailer = Header{
   621  		"Client-Trailer-A": nil, //  to be set later
   622  		"Client-Trailer-B": nil, //  to be set later
   623  	}
   624  	req.ContentLength = -1
   625  	res, err := cst.c.Do(req)
   626  	if err != nil {
   627  		t.Fatal(err)
   628  	}
   629  	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
   630  		t.Error(err)
   631  	}
   632  }
   633  
   634  // Tests that servers send trailers to a client and that the client can read them.
   635  func TestTrailersServerToClient_h1(t *testing.T)       { testTrailersServerToClient(t, h1Mode, false) }
   636  func TestTrailersServerToClient_h2(t *testing.T)       { testTrailersServerToClient(t, h2Mode, false) }
   637  func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) }
   638  func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) }
   639  
   640  func testTrailersServerToClient(t *testing.T, h2, flush bool) {
   641  	defer afterTest(t)
   642  	const body = "Some body"
   643  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   644  		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
   645  		w.Header().Add("Trailer", "Server-Trailer-C")
   646  
   647  		io.WriteString(w, body)
   648  		if flush {
   649  			w.(Flusher).Flush()
   650  		}
   651  
   652  		// How handlers set Trailers: declare it ahead of time
   653  		// with the Trailer header, and then mutate the
   654  		// Header() of those values later, after the response
   655  		// has been written (we wrote to w above).
   656  		w.Header().Set("Server-Trailer-A", "valuea")
   657  		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
   658  		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
   659  	}))
   660  	defer cst.close()
   661  
   662  	res, err := cst.c.Get(cst.ts.URL)
   663  	if err != nil {
   664  		t.Fatal(err)
   665  	}
   666  
   667  	wantHeader := Header{
   668  		"Content-Type": {"text/plain; charset=utf-8"},
   669  	}
   670  	wantLen := -1
   671  	if h2 && !flush {
   672  		// In HTTP/1.1, any use of trailers forces HTTP/1.1
   673  		// chunking and a flush at the first write. That's
   674  		// unnecessary with HTTP/2's framing, so the server
   675  		// is able to calculate the length while still sending
   676  		// trailers afterwards.
   677  		wantLen = len(body)
   678  		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
   679  	}
   680  	if res.ContentLength != int64(wantLen) {
   681  		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
   682  	}
   683  
   684  	delete(res.Header, "Date") // irrelevant for test
   685  	if !reflect.DeepEqual(res.Header, wantHeader) {
   686  		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
   687  	}
   688  
   689  	if got, want := res.Trailer, (Header{
   690  		"Server-Trailer-A": nil,
   691  		"Server-Trailer-B": nil,
   692  		"Server-Trailer-C": nil,
   693  	}); !reflect.DeepEqual(got, want) {
   694  		t.Errorf("Trailer before body read = %v; want %v", got, want)
   695  	}
   696  
   697  	if err := wantBody(res, nil, body); err != nil {
   698  		t.Fatal(err)
   699  	}
   700  
   701  	if got, want := res.Trailer, (Header{
   702  		"Server-Trailer-A": {"valuea"},
   703  		"Server-Trailer-B": nil,
   704  		"Server-Trailer-C": {"valuec"},
   705  	}); !reflect.DeepEqual(got, want) {
   706  		t.Errorf("Trailer after body read = %v; want %v", got, want)
   707  	}
   708  }
   709  
   710  // Don't allow a Body.Read after Body.Close. Issue 13648.
   711  func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) }
   712  func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) }
   713  
   714  func testResponseBodyReadAfterClose(t *testing.T, h2 bool) {
   715  	defer afterTest(t)
   716  	const body = "Some body"
   717  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   718  		io.WriteString(w, body)
   719  	}))
   720  	defer cst.close()
   721  	res, err := cst.c.Get(cst.ts.URL)
   722  	if err != nil {
   723  		t.Fatal(err)
   724  	}
   725  	res.Body.Close()
   726  	data, err := io.ReadAll(res.Body)
   727  	if len(data) != 0 || err == nil {
   728  		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
   729  	}
   730  }
   731  
   732  func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) }
   733  func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) }
   734  func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) {
   735  	defer afterTest(t)
   736  	const reqBody = "some request body"
   737  	const resBody = "some response body"
   738  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   739  		var wg sync.WaitGroup
   740  		wg.Add(2)
   741  		didRead := make(chan bool, 1)
   742  		// Read in one goroutine.
   743  		go func() {
   744  			defer wg.Done()
   745  			data, err := io.ReadAll(r.Body)
   746  			if string(data) != reqBody {
   747  				t.Errorf("Handler read %q; want %q", data, reqBody)
   748  			}
   749  			if err != nil {
   750  				t.Errorf("Handler Read: %v", err)
   751  			}
   752  			didRead <- true
   753  		}()
   754  		// Write in another goroutine.
   755  		go func() {
   756  			defer wg.Done()
   757  			if !h2 {
   758  				// our HTTP/1 implementation intentionally
   759  				// doesn't permit writes during read (mostly
   760  				// due to it being undefined); if that is ever
   761  				// relaxed, change this.
   762  				<-didRead
   763  			}
   764  			io.WriteString(w, resBody)
   765  		}()
   766  		wg.Wait()
   767  	}))
   768  	defer cst.close()
   769  	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
   770  	req.Header.Add("Expect", "100-continue") // just to complicate things
   771  	res, err := cst.c.Do(req)
   772  	if err != nil {
   773  		t.Fatal(err)
   774  	}
   775  	data, err := io.ReadAll(res.Body)
   776  	defer res.Body.Close()
   777  	if err != nil {
   778  		t.Fatal(err)
   779  	}
   780  	if string(data) != resBody {
   781  		t.Errorf("read %q; want %q", data, resBody)
   782  	}
   783  }
   784  
   785  func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) }
   786  func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) }
   787  func testConnectRequest(t *testing.T, h2 bool) {
   788  	defer afterTest(t)
   789  	gotc := make(chan *Request, 1)
   790  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   791  		gotc <- r
   792  	}))
   793  	defer cst.close()
   794  
   795  	u, err := url.Parse(cst.ts.URL)
   796  	if err != nil {
   797  		t.Fatal(err)
   798  	}
   799  
   800  	tests := []struct {
   801  		req  *Request
   802  		want string
   803  	}{
   804  		{
   805  			req: &Request{
   806  				Method: "CONNECT",
   807  				Header: Header{},
   808  				URL:    u,
   809  			},
   810  			want: u.Host,
   811  		},
   812  		{
   813  			req: &Request{
   814  				Method: "CONNECT",
   815  				Header: Header{},
   816  				URL:    u,
   817  				Host:   "example.com:123",
   818  			},
   819  			want: "example.com:123",
   820  		},
   821  	}
   822  
   823  	for i, tt := range tests {
   824  		res, err := cst.c.Do(tt.req)
   825  		if err != nil {
   826  			t.Errorf("%d. RoundTrip = %v", i, err)
   827  			continue
   828  		}
   829  		res.Body.Close()
   830  		req := <-gotc
   831  		if req.Method != "CONNECT" {
   832  			t.Errorf("method = %q; want CONNECT", req.Method)
   833  		}
   834  		if req.Host != tt.want {
   835  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
   836  		}
   837  		if req.URL.Host != tt.want {
   838  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
   839  		}
   840  	}
   841  }
   842  
   843  func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) }
   844  func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) }
   845  func testTransportUserAgent(t *testing.T, h2 bool) {
   846  	defer afterTest(t)
   847  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   848  		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
   849  	}))
   850  	defer cst.close()
   851  
   852  	either := func(a, b string) string {
   853  		if h2 {
   854  			return b
   855  		}
   856  		return a
   857  	}
   858  
   859  	tests := []struct {
   860  		setup func(*Request)
   861  		want  string
   862  	}{
   863  		{
   864  			func(r *Request) {},
   865  			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
   866  		},
   867  		{
   868  			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
   869  			`["foo/1.2.3"]`,
   870  		},
   871  		{
   872  			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
   873  			`["single"]`,
   874  		},
   875  		{
   876  			func(r *Request) { r.Header.Set("User-Agent", "") },
   877  			`[]`,
   878  		},
   879  		{
   880  			func(r *Request) { r.Header["User-Agent"] = nil },
   881  			`[]`,
   882  		},
   883  	}
   884  	for i, tt := range tests {
   885  		req, _ := NewRequest("GET", cst.ts.URL, nil)
   886  		tt.setup(req)
   887  		res, err := cst.c.Do(req)
   888  		if err != nil {
   889  			t.Errorf("%d. RoundTrip = %v", i, err)
   890  			continue
   891  		}
   892  		slurp, err := io.ReadAll(res.Body)
   893  		res.Body.Close()
   894  		if err != nil {
   895  			t.Errorf("%d. read body = %v", i, err)
   896  			continue
   897  		}
   898  		if string(slurp) != tt.want {
   899  			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
   900  		}
   901  	}
   902  }
   903  
   904  func TestStarRequestFoo_h1(t *testing.T)     { testStarRequest(t, "FOO", h1Mode) }
   905  func TestStarRequestFoo_h2(t *testing.T)     { testStarRequest(t, "FOO", h2Mode) }
   906  func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) }
   907  func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) }
   908  func testStarRequest(t *testing.T, method string, h2 bool) {
   909  	defer afterTest(t)
   910  	gotc := make(chan *Request, 1)
   911  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   912  		w.Header().Set("foo", "bar")
   913  		gotc <- r
   914  		w.(Flusher).Flush()
   915  	}))
   916  	defer cst.close()
   917  
   918  	u, err := url.Parse(cst.ts.URL)
   919  	if err != nil {
   920  		t.Fatal(err)
   921  	}
   922  	u.Path = "*"
   923  
   924  	req := &Request{
   925  		Method: method,
   926  		Header: Header{},
   927  		URL:    u,
   928  	}
   929  
   930  	res, err := cst.c.Do(req)
   931  	if err != nil {
   932  		t.Fatalf("RoundTrip = %v", err)
   933  	}
   934  	res.Body.Close()
   935  
   936  	wantFoo := "bar"
   937  	wantLen := int64(-1)
   938  	if method == "OPTIONS" {
   939  		wantFoo = ""
   940  		wantLen = 0
   941  	}
   942  	if res.StatusCode != 200 {
   943  		t.Errorf("status code = %v; want %d", res.Status, 200)
   944  	}
   945  	if res.ContentLength != wantLen {
   946  		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
   947  	}
   948  	if got := res.Header.Get("foo"); got != wantFoo {
   949  		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
   950  	}
   951  	select {
   952  	case req = <-gotc:
   953  	default:
   954  		req = nil
   955  	}
   956  	if req == nil {
   957  		if method != "OPTIONS" {
   958  			t.Fatalf("handler never got request")
   959  		}
   960  		return
   961  	}
   962  	if req.Method != method {
   963  		t.Errorf("method = %q; want %q", req.Method, method)
   964  	}
   965  	if req.URL.Path != "*" {
   966  		t.Errorf("URL.Path = %q; want *", req.URL.Path)
   967  	}
   968  	if req.RequestURI != "*" {
   969  		t.Errorf("RequestURI = %q; want *", req.RequestURI)
   970  	}
   971  }
   972  
   973  // Issue 13957
   974  func TestTransportDiscardsUnneededConns(t *testing.T) {
   975  	setParallel(t)
   976  	defer afterTest(t)
   977  	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   978  		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
   979  	}))
   980  	defer cst.close()
   981  
   982  	var numOpen, numClose int32 // atomic
   983  
   984  	tlsConfig := &tls.Config{InsecureSkipVerify: true}
   985  	tr := &Transport{
   986  		TLSClientConfig: tlsConfig,
   987  		DialTLS: func(_, addr string) (net.Conn, error) {
   988  			time.Sleep(10 * time.Millisecond)
   989  			rc, err := net.Dial("tcp", addr)
   990  			if err != nil {
   991  				return nil, err
   992  			}
   993  			atomic.AddInt32(&numOpen, 1)
   994  			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
   995  			return tls.Client(c, tlsConfig), nil
   996  		},
   997  	}
   998  	if err := ExportHttp2ConfigureTransport(tr); err != nil {
   999  		t.Fatal(err)
  1000  	}
  1001  	defer tr.CloseIdleConnections()
  1002  
  1003  	c := &Client{Transport: tr}
  1004  
  1005  	const N = 10
  1006  	gotBody := make(chan string, N)
  1007  	var wg sync.WaitGroup
  1008  	for i := 0; i < N; i++ {
  1009  		wg.Add(1)
  1010  		go func() {
  1011  			defer wg.Done()
  1012  			resp, err := c.Get(cst.ts.URL)
  1013  			if err != nil {
  1014  				// Try to work around spurious connection reset on loaded system.
  1015  				// See golang.org/issue/33585 and golang.org/issue/36797.
  1016  				time.Sleep(10 * time.Millisecond)
  1017  				resp, err = c.Get(cst.ts.URL)
  1018  				if err != nil {
  1019  					t.Errorf("Get: %v", err)
  1020  					return
  1021  				}
  1022  			}
  1023  			defer resp.Body.Close()
  1024  			slurp, err := io.ReadAll(resp.Body)
  1025  			if err != nil {
  1026  				t.Error(err)
  1027  			}
  1028  			gotBody <- string(slurp)
  1029  		}()
  1030  	}
  1031  	wg.Wait()
  1032  	close(gotBody)
  1033  
  1034  	var last string
  1035  	for got := range gotBody {
  1036  		if last == "" {
  1037  			last = got
  1038  			continue
  1039  		}
  1040  		if got != last {
  1041  			t.Errorf("Response body changed: %q -> %q", last, got)
  1042  		}
  1043  	}
  1044  
  1045  	var open, close int32
  1046  	for i := 0; i < 150; i++ {
  1047  		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
  1048  		if open < 1 {
  1049  			t.Fatalf("open = %d; want at least", open)
  1050  		}
  1051  		if close == open-1 {
  1052  			// Success
  1053  			return
  1054  		}
  1055  		time.Sleep(10 * time.Millisecond)
  1056  	}
  1057  	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
  1058  }
  1059  
  1060  // tests that Transport doesn't retain a pointer to the provided request.
  1061  func TestTransportGCRequest_Body_h1(t *testing.T)   { testTransportGCRequest(t, h1Mode, true) }
  1062  func TestTransportGCRequest_Body_h2(t *testing.T)   { testTransportGCRequest(t, h2Mode, true) }
  1063  func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) }
  1064  func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) }
  1065  func testTransportGCRequest(t *testing.T, h2, body bool) {
  1066  	setParallel(t)
  1067  	defer afterTest(t)
  1068  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1069  		io.ReadAll(r.Body)
  1070  		if body {
  1071  			io.WriteString(w, "Hello.")
  1072  		}
  1073  	}))
  1074  	defer cst.close()
  1075  
  1076  	didGC := make(chan struct{})
  1077  	(func() {
  1078  		body := strings.NewReader("some body")
  1079  		req, _ := NewRequest("POST", cst.ts.URL, body)
  1080  		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
  1081  		res, err := cst.c.Do(req)
  1082  		if err != nil {
  1083  			t.Fatal(err)
  1084  		}
  1085  		if _, err := io.ReadAll(res.Body); err != nil {
  1086  			t.Fatal(err)
  1087  		}
  1088  		if err := res.Body.Close(); err != nil {
  1089  			t.Fatal(err)
  1090  		}
  1091  	})()
  1092  	timeout := time.NewTimer(5 * time.Second)
  1093  	defer timeout.Stop()
  1094  	for {
  1095  		select {
  1096  		case <-didGC:
  1097  			return
  1098  		case <-time.After(100 * time.Millisecond):
  1099  			runtime.GC()
  1100  		case <-timeout.C:
  1101  			t.Fatal("never saw GC of request")
  1102  		}
  1103  	}
  1104  }
  1105  
  1106  func TestTransportRejectsInvalidHeaders_h1(t *testing.T) {
  1107  	testTransportRejectsInvalidHeaders(t, h1Mode)
  1108  }
  1109  func TestTransportRejectsInvalidHeaders_h2(t *testing.T) {
  1110  	testTransportRejectsInvalidHeaders(t, h2Mode)
  1111  }
  1112  func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) {
  1113  	setParallel(t)
  1114  	defer afterTest(t)
  1115  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1116  		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
  1117  	}), optQuietLog)
  1118  	defer cst.close()
  1119  	cst.tr.DisableKeepAlives = true
  1120  
  1121  	tests := []struct {
  1122  		key, val string
  1123  		ok       bool
  1124  	}{
  1125  		{"Foo", "capital-key", true}, // verify h2 allows capital keys
  1126  		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
  1127  		{"Foo", "two\nlines", false}, // \n byte in value not allowed
  1128  		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
  1129  		{"A space", "v", false},      // spaces in keys not allowed
  1130  		{"имя", "v", false},          // key must be ascii
  1131  		{"name", "валю", true},       // value may be non-ascii
  1132  		{"", "v", false},             // key must be non-empty
  1133  		{"k", "", true},              // value may be empty
  1134  	}
  1135  	for _, tt := range tests {
  1136  		dialedc := make(chan bool, 1)
  1137  		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  1138  			dialedc <- true
  1139  			return net.Dial(netw, addr)
  1140  		}
  1141  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1142  		req.Header[tt.key] = []string{tt.val}
  1143  		res, err := cst.c.Do(req)
  1144  		var body []byte
  1145  		if err == nil {
  1146  			body, _ = io.ReadAll(res.Body)
  1147  			res.Body.Close()
  1148  		}
  1149  		var dialed bool
  1150  		select {
  1151  		case <-dialedc:
  1152  			dialed = true
  1153  		default:
  1154  		}
  1155  
  1156  		if !tt.ok && dialed {
  1157  			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
  1158  		} else if (err == nil) != tt.ok {
  1159  			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
  1160  		}
  1161  	}
  1162  }
  1163  
  1164  func TestInterruptWithPanic_h1(t *testing.T)     { testInterruptWithPanic(t, h1Mode, "boom") }
  1165  func TestInterruptWithPanic_h2(t *testing.T)     { testInterruptWithPanic(t, h2Mode, "boom") }
  1166  func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) }
  1167  func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) }
  1168  func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) {
  1169  	testInterruptWithPanic(t, h1Mode, ErrAbortHandler)
  1170  }
  1171  func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) {
  1172  	testInterruptWithPanic(t, h2Mode, ErrAbortHandler)
  1173  }
  1174  func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) {
  1175  	setParallel(t)
  1176  	const msg = "hello"
  1177  	defer afterTest(t)
  1178  
  1179  	testDone := make(chan struct{})
  1180  	defer close(testDone)
  1181  
  1182  	var errorLog lockedBytesBuffer
  1183  	gotHeaders := make(chan bool, 1)
  1184  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1185  		io.WriteString(w, msg)
  1186  		w.(Flusher).Flush()
  1187  
  1188  		select {
  1189  		case <-gotHeaders:
  1190  		case <-testDone:
  1191  		}
  1192  		panic(panicValue)
  1193  	}), func(ts *httptest.Server) {
  1194  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1195  	})
  1196  	defer cst.close()
  1197  	res, err := cst.c.Get(cst.ts.URL)
  1198  	if err != nil {
  1199  		t.Fatal(err)
  1200  	}
  1201  	gotHeaders <- true
  1202  	defer res.Body.Close()
  1203  	slurp, err := io.ReadAll(res.Body)
  1204  	if string(slurp) != msg {
  1205  		t.Errorf("client read %q; want %q", slurp, msg)
  1206  	}
  1207  	if err == nil {
  1208  		t.Errorf("client read all successfully; want some error")
  1209  	}
  1210  	logOutput := func() string {
  1211  		errorLog.Lock()
  1212  		defer errorLog.Unlock()
  1213  		return errorLog.String()
  1214  	}
  1215  	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
  1216  
  1217  	if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error {
  1218  		gotLog := logOutput()
  1219  		if !wantStackLogged {
  1220  			if gotLog == "" {
  1221  				return nil
  1222  			}
  1223  			return fmt.Errorf("want no log output; got: %s", gotLog)
  1224  		}
  1225  		if gotLog == "" {
  1226  			return fmt.Errorf("wanted a stack trace logged; got nothing")
  1227  		}
  1228  		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
  1229  			return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog)
  1230  		}
  1231  		return nil
  1232  	}); err != nil {
  1233  		t.Fatal(err)
  1234  	}
  1235  }
  1236  
  1237  type lockedBytesBuffer struct {
  1238  	sync.Mutex
  1239  	bytes.Buffer
  1240  }
  1241  
  1242  func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
  1243  	b.Lock()
  1244  	defer b.Unlock()
  1245  	return b.Buffer.Write(p)
  1246  }
  1247  
  1248  // Issue 15366
  1249  func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
  1250  	h12Compare{
  1251  		Handler: func(w ResponseWriter, r *Request) {
  1252  			h := w.Header()
  1253  			h.Set("Content-Encoding", "gzip")
  1254  			h.Set("Content-Length", "23")
  1255  			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
  1256  		},
  1257  		EarlyCheckResponse: func(proto string, res *Response) {
  1258  			if !res.Uncompressed {
  1259  				t.Errorf("%s: expected Uncompressed to be set", proto)
  1260  			}
  1261  			dump, err := httputil.DumpResponse(res, true)
  1262  			if err != nil {
  1263  				t.Errorf("%s: DumpResponse: %v", proto, err)
  1264  				return
  1265  			}
  1266  			if strings.Contains(string(dump), "Connection: close") {
  1267  				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
  1268  			}
  1269  			if !strings.Contains(string(dump), "FOO") {
  1270  				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
  1271  			}
  1272  		},
  1273  	}.run(t)
  1274  }
  1275  
  1276  // Issue 14607
  1277  func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) }
  1278  func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) }
  1279  func testCloseIdleConnections(t *testing.T, h2 bool) {
  1280  	setParallel(t)
  1281  	defer afterTest(t)
  1282  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1283  		w.Header().Set("X-Addr", r.RemoteAddr)
  1284  	}))
  1285  	defer cst.close()
  1286  	get := func() string {
  1287  		res, err := cst.c.Get(cst.ts.URL)
  1288  		if err != nil {
  1289  			t.Fatal(err)
  1290  		}
  1291  		res.Body.Close()
  1292  		v := res.Header.Get("X-Addr")
  1293  		if v == "" {
  1294  			t.Fatal("didn't get X-Addr")
  1295  		}
  1296  		return v
  1297  	}
  1298  	a1 := get()
  1299  	cst.tr.CloseIdleConnections()
  1300  	a2 := get()
  1301  	if a1 == a2 {
  1302  		t.Errorf("didn't close connection")
  1303  	}
  1304  }
  1305  
  1306  type noteCloseConn struct {
  1307  	net.Conn
  1308  	closeFunc func()
  1309  }
  1310  
  1311  func (x noteCloseConn) Close() error {
  1312  	x.closeFunc()
  1313  	return x.Conn.Close()
  1314  }
  1315  
  1316  type testErrorReader struct{ t *testing.T }
  1317  
  1318  func (r testErrorReader) Read(p []byte) (n int, err error) {
  1319  	r.t.Error("unexpected Read call")
  1320  	return 0, io.EOF
  1321  }
  1322  
  1323  func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) }
  1324  func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) }
  1325  
  1326  func testNoSniffExpectRequestBody(t *testing.T, h2 bool) {
  1327  	defer afterTest(t)
  1328  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1329  		w.WriteHeader(StatusUnauthorized)
  1330  	}))
  1331  	defer cst.close()
  1332  
  1333  	// Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
  1334  	cst.tr.ExpectContinueTimeout = 10 * time.Second
  1335  
  1336  	req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
  1337  	if err != nil {
  1338  		t.Fatal(err)
  1339  	}
  1340  	req.ContentLength = 0 // so transport is tempted to sniff it
  1341  	req.Header.Set("Expect", "100-continue")
  1342  	res, err := cst.tr.RoundTrip(req)
  1343  	if err != nil {
  1344  		t.Fatal(err)
  1345  	}
  1346  	defer res.Body.Close()
  1347  	if res.StatusCode != StatusUnauthorized {
  1348  		t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
  1349  	}
  1350  }
  1351  
  1352  func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) }
  1353  func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) }
  1354  func testServerUndeclaredTrailers(t *testing.T, h2 bool) {
  1355  	defer afterTest(t)
  1356  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1357  		w.Header().Set("Foo", "Bar")
  1358  		w.Header().Set("Trailer:Foo", "Baz")
  1359  		w.(Flusher).Flush()
  1360  		w.Header().Add("Trailer:Foo", "Baz2")
  1361  		w.Header().Set("Trailer:Bar", "Quux")
  1362  	}))
  1363  	defer cst.close()
  1364  	res, err := cst.c.Get(cst.ts.URL)
  1365  	if err != nil {
  1366  		t.Fatal(err)
  1367  	}
  1368  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  1369  		t.Fatal(err)
  1370  	}
  1371  	res.Body.Close()
  1372  	delete(res.Header, "Date")
  1373  	delete(res.Header, "Content-Type")
  1374  
  1375  	if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
  1376  		t.Errorf("Header = %#v; want %#v", res.Header, want)
  1377  	}
  1378  	if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
  1379  		t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
  1380  	}
  1381  }
  1382  
  1383  func TestBadResponseAfterReadingBody(t *testing.T) {
  1384  	defer afterTest(t)
  1385  	cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) {
  1386  		_, err := io.Copy(io.Discard, r.Body)
  1387  		if err != nil {
  1388  			t.Fatal(err)
  1389  		}
  1390  		c, _, err := w.(Hijacker).Hijack()
  1391  		if err != nil {
  1392  			t.Fatal(err)
  1393  		}
  1394  		defer c.Close()
  1395  		fmt.Fprintln(c, "some bogus crap")
  1396  	}))
  1397  	defer cst.close()
  1398  
  1399  	closes := 0
  1400  	res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  1401  	if err == nil {
  1402  		res.Body.Close()
  1403  		t.Fatal("expected an error to be returned from Post")
  1404  	}
  1405  	if closes != 1 {
  1406  		t.Errorf("closes = %d; want 1", closes)
  1407  	}
  1408  }
  1409  
  1410  func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) }
  1411  func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) }
  1412  func testWriteHeader0(t *testing.T, h2 bool) {
  1413  	defer afterTest(t)
  1414  	gotpanic := make(chan bool, 1)
  1415  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1416  		defer close(gotpanic)
  1417  		defer func() {
  1418  			if e := recover(); e != nil {
  1419  				got := fmt.Sprintf("%T, %v", e, e)
  1420  				want := "string, invalid WriteHeader code 0"
  1421  				if got != want {
  1422  					t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
  1423  				}
  1424  				gotpanic <- true
  1425  
  1426  				// Set an explicit 503. This also tests that the WriteHeader call panics
  1427  				// before it recorded that an explicit value was set and that bogus
  1428  				// value wasn't stuck.
  1429  				w.WriteHeader(503)
  1430  			}
  1431  		}()
  1432  		w.WriteHeader(0)
  1433  	}))
  1434  	defer cst.close()
  1435  	res, err := cst.c.Get(cst.ts.URL)
  1436  	if err != nil {
  1437  		t.Fatal(err)
  1438  	}
  1439  	if res.StatusCode != 503 {
  1440  		t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
  1441  	}
  1442  	if !<-gotpanic {
  1443  		t.Error("expected panic in handler")
  1444  	}
  1445  }
  1446  
  1447  // Issue 23010: don't be super strict checking WriteHeader's code if
  1448  // it's not even valid to call WriteHeader then anyway.
  1449  func TestWriteHeaderNoCodeCheck_h1(t *testing.T)       { testWriteHeaderAfterWrite(t, h1Mode, false) }
  1450  func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) }
  1451  func TestWriteHeaderNoCodeCheck_h2(t *testing.T)       { testWriteHeaderAfterWrite(t, h2Mode, false) }
  1452  func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) {
  1453  	setParallel(t)
  1454  	defer afterTest(t)
  1455  
  1456  	var errorLog lockedBytesBuffer
  1457  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1458  		if hijack {
  1459  			conn, _, _ := w.(Hijacker).Hijack()
  1460  			defer conn.Close()
  1461  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
  1462  			w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1463  			conn.Write([]byte("bar"))
  1464  			return
  1465  		}
  1466  		io.WriteString(w, "foo")
  1467  		w.(Flusher).Flush()
  1468  		w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1469  		io.WriteString(w, "bar")
  1470  	}), func(ts *httptest.Server) {
  1471  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1472  	})
  1473  	defer cst.close()
  1474  	res, err := cst.c.Get(cst.ts.URL)
  1475  	if err != nil {
  1476  		t.Fatal(err)
  1477  	}
  1478  	defer res.Body.Close()
  1479  	body, err := io.ReadAll(res.Body)
  1480  	if err != nil {
  1481  		t.Fatal(err)
  1482  	}
  1483  	if got, want := string(body), "foobar"; got != want {
  1484  		t.Errorf("got = %q; want %q", got, want)
  1485  	}
  1486  
  1487  	// Also check the stderr output:
  1488  	if h2 {
  1489  		// TODO: also emit this log message for HTTP/2?
  1490  		// We historically haven't, so don't check.
  1491  		return
  1492  	}
  1493  	gotLog := strings.TrimSpace(errorLog.String())
  1494  	wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1495  	if hijack {
  1496  		wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1497  	}
  1498  	if !strings.HasPrefix(gotLog, wantLog) {
  1499  		t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
  1500  	}
  1501  }
  1502  
  1503  func TestBidiStreamReverseProxy(t *testing.T) {
  1504  	setParallel(t)
  1505  	defer afterTest(t)
  1506  	backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1507  		if _, err := io.Copy(w, r.Body); err != nil {
  1508  			log.Printf("bidi backend copy: %v", err)
  1509  		}
  1510  	}))
  1511  	defer backend.close()
  1512  
  1513  	backURL, err := url.Parse(backend.ts.URL)
  1514  	if err != nil {
  1515  		t.Fatal(err)
  1516  	}
  1517  	rp := httputil.NewSingleHostReverseProxy(backURL)
  1518  	rp.Transport = backend.tr
  1519  	proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1520  		rp.ServeHTTP(w, r)
  1521  	}))
  1522  	defer proxy.close()
  1523  
  1524  	bodyRes := make(chan any, 1) // error or hash.Hash
  1525  	pr, pw := io.Pipe()
  1526  	req, _ := NewRequest("PUT", proxy.ts.URL, pr)
  1527  	const size = 4 << 20
  1528  	go func() {
  1529  		h := sha1.New()
  1530  		_, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
  1531  		go pw.Close()
  1532  		if err != nil {
  1533  			bodyRes <- err
  1534  		} else {
  1535  			bodyRes <- h
  1536  		}
  1537  	}()
  1538  	res, err := backend.c.Do(req)
  1539  	if err != nil {
  1540  		t.Fatal(err)
  1541  	}
  1542  	defer res.Body.Close()
  1543  	hgot := sha1.New()
  1544  	n, err := io.Copy(hgot, res.Body)
  1545  	if err != nil {
  1546  		t.Fatal(err)
  1547  	}
  1548  	if n != size {
  1549  		t.Fatalf("got %d bytes; want %d", n, size)
  1550  	}
  1551  	select {
  1552  	case v := <-bodyRes:
  1553  		switch v := v.(type) {
  1554  		default:
  1555  			t.Fatalf("body copy: %v", err)
  1556  		case hash.Hash:
  1557  			if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
  1558  				t.Errorf("written bytes didn't match received bytes")
  1559  			}
  1560  		}
  1561  	case <-time.After(10 * time.Second):
  1562  		t.Fatal("timeout")
  1563  	}
  1564  
  1565  }
  1566  
  1567  // Always use HTTP/1.1 for WebSocket upgrades.
  1568  func TestH12_WebSocketUpgrade(t *testing.T) {
  1569  	h12Compare{
  1570  		Handler: func(w ResponseWriter, r *Request) {
  1571  			h := w.Header()
  1572  			h.Set("Foo", "bar")
  1573  		},
  1574  		ReqFunc: func(c *Client, url string) (*Response, error) {
  1575  			req, _ := NewRequest("GET", url, nil)
  1576  			req.Header.Set("Connection", "Upgrade")
  1577  			req.Header.Set("Upgrade", "WebSocket")
  1578  			return c.Do(req)
  1579  		},
  1580  		EarlyCheckResponse: func(proto string, res *Response) {
  1581  			if res.Proto != "HTTP/1.1" {
  1582  				t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
  1583  			}
  1584  			res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
  1585  		},
  1586  	}.run(t)
  1587  }
  1588  
  1589  func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) }
  1590  func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) }
  1591  
  1592  func testIdentityTransferEncoding(t *testing.T, h2 bool) {
  1593  	setParallel(t)
  1594  	defer afterTest(t)
  1595  
  1596  	const body = "body"
  1597  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1598  		gotBody, _ := io.ReadAll(r.Body)
  1599  		if got, want := string(gotBody), body; got != want {
  1600  			t.Errorf("got request body = %q; want %q", got, want)
  1601  		}
  1602  		w.Header().Set("Transfer-Encoding", "identity")
  1603  		w.WriteHeader(StatusOK)
  1604  		w.(Flusher).Flush()
  1605  		io.WriteString(w, body)
  1606  	}))
  1607  	defer cst.close()
  1608  	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
  1609  	res, err := cst.c.Do(req)
  1610  	if err != nil {
  1611  		t.Fatal(err)
  1612  	}
  1613  	defer res.Body.Close()
  1614  	gotBody, err := io.ReadAll(res.Body)
  1615  	if err != nil {
  1616  		t.Fatal(err)
  1617  	}
  1618  	if got, want := string(gotBody), body; got != want {
  1619  		t.Errorf("got response body = %q; want %q", got, want)
  1620  	}
  1621  }
  1622  
  1623  func TestEarlyHintsRequest_h1(t *testing.T) { testEarlyHintsRequest(t, h1Mode) }
  1624  func TestEarlyHintsRequest_h2(t *testing.T) { testEarlyHintsRequest(t, h2Mode) }
  1625  func testEarlyHintsRequest(t *testing.T, h2 bool) {
  1626  	defer afterTest(t)
  1627  
  1628  	var wg sync.WaitGroup
  1629  	wg.Add(1)
  1630  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1631  		h := w.Header()
  1632  
  1633  		h.Add("Content-Length", "123") // must be ignored
  1634  		h.Add("Link", "</style.css>; rel=preload; as=style")
  1635  		h.Add("Link", "</script.js>; rel=preload; as=script")
  1636  		w.WriteHeader(StatusEarlyHints)
  1637  
  1638  		wg.Wait()
  1639  
  1640  		h.Add("Link", "</foo.js>; rel=preload; as=script")
  1641  		w.WriteHeader(StatusEarlyHints)
  1642  
  1643  		w.Write([]byte("Hello"))
  1644  	}))
  1645  	defer cst.close()
  1646  
  1647  	checkLinkHeaders := func(t *testing.T, expected, got []string) {
  1648  		t.Helper()
  1649  
  1650  		if len(expected) != len(got) {
  1651  			t.Errorf("got %d expected %d", len(got), len(expected))
  1652  		}
  1653  
  1654  		for i := range expected {
  1655  			if expected[i] != got[i] {
  1656  				t.Errorf("got %q expected %q", got[i], expected[i])
  1657  			}
  1658  		}
  1659  	}
  1660  
  1661  	checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
  1662  		t.Helper()
  1663  
  1664  		for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
  1665  			if v, ok := header[h]; ok {
  1666  				t.Errorf("%s is %q; must not be sent", h, v)
  1667  			}
  1668  		}
  1669  	}
  1670  
  1671  	var respCounter uint8
  1672  	trace := &httptrace.ClientTrace{
  1673  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1674  			switch respCounter {
  1675  			case 0:
  1676  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
  1677  				checkExcludedHeaders(t, header)
  1678  
  1679  				wg.Done()
  1680  			case 1:
  1681  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
  1682  				checkExcludedHeaders(t, header)
  1683  
  1684  			default:
  1685  				t.Error("Unexpected 1xx response")
  1686  			}
  1687  
  1688  			respCounter++
  1689  
  1690  			return nil
  1691  		},
  1692  	}
  1693  	req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
  1694  
  1695  	res, err := cst.c.Do(req)
  1696  	if err != nil {
  1697  		t.Fatal(err)
  1698  	}
  1699  	defer res.Body.Close()
  1700  
  1701  	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
  1702  	if cl := res.Header.Get("Content-Length"); cl != "123" {
  1703  		t.Errorf("Content-Length is %q; want 123", cl)
  1704  	}
  1705  
  1706  	body, _ := io.ReadAll(res.Body)
  1707  	if string(body) != "Hello" {
  1708  		t.Errorf("Read body %q; want Hello", body)
  1709  	}
  1710  }
  1711  

View as plain text