Source file
src/net/http/clientserver_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "context"
13 "crypto/rand"
14 "crypto/sha1"
15 "crypto/tls"
16 "fmt"
17 "hash"
18 "io"
19 "log"
20 "net"
21 . "net/http"
22 "net/http/httptest"
23 "net/http/httptrace"
24 "net/http/httputil"
25 "net/textproto"
26 "net/url"
27 "os"
28 "reflect"
29 "runtime"
30 "sort"
31 "strings"
32 "sync"
33 "sync/atomic"
34 "testing"
35 "time"
36 )
37
// testMode names the protocol configuration a client/server test runs under.
type testMode string

// Test modes understood by run and newClientServerTest.
const (
	http1Mode  testMode = "h1"     // server started with Start
	https1Mode testMode = "https1" // server started with StartTLS
	http2Mode  testMode = "h2"     // HTTP/2-configured server started with StartTLS
)
45
// testNotParallelOpt is the type of the testNotParallel option value.
type testNotParallelOpt struct{}

// testNotParallel is an option for run: when present, run does not mark
// the test or its per-mode subtests as parallel.
var testNotParallel = testNotParallelOpt{}
51
// TBRun is a testing.TB that can additionally run named subtests whose
// callback receives a value of type T. The run helper uses it as the
// constraint on its test parameter.
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}
56
57
58
59
60
61
62
63
64 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
65 t.Helper()
66 modes := []testMode{http1Mode, http2Mode}
67 parallel := true
68 for _, opt := range opts {
69 switch opt := opt.(type) {
70 case []testMode:
71 modes = opt
72 case testNotParallelOpt:
73 parallel = false
74 default:
75 t.Fatalf("unknown option type %T", opt)
76 }
77 }
78 if t, ok := any(t).(*testing.T); ok && parallel {
79 setParallel(t)
80 }
81 for _, mode := range modes {
82 t.Run(string(mode), func(t T) {
83 t.Helper()
84 if t, ok := any(t).(*testing.T); ok && parallel {
85 setParallel(t)
86 }
87 t.Cleanup(func() {
88 afterTest(t)
89 })
90 f(t, mode)
91 })
92 }
93 }
94
// clientServerTest holds the server, client and transport created by
// newClientServerTest for a single test.
type clientServerTest struct {
	t  testing.TB       // test to fail when a helper such as getURL hits an error
	h2 bool             // set when the test was created in http2Mode
	h  Handler          // handler the server was created with
	ts *httptest.Server // the test server
	tr *Transport       // transport used by c
	c  *Client          // client obtained from ts.Client()
}
103
104 func (t *clientServerTest) close() {
105 t.tr.CloseIdleConnections()
106 t.ts.Close()
107 }
108
109 func (t *clientServerTest) getURL(u string) string {
110 res, err := t.c.Get(u)
111 if err != nil {
112 t.t.Fatal(err)
113 }
114 defer res.Body.Close()
115 slurp, err := io.ReadAll(res.Body)
116 if err != nil {
117 t.t.Fatal(err)
118 }
119 return string(slurp)
120 }
121
122 func (t *clientServerTest) scheme() string {
123 if t.h2 {
124 return "https"
125 }
126 return "http"
127 }
128
129 var optQuietLog = func(ts *httptest.Server) {
130 ts.Config.ErrorLog = quietLog
131 }
132
// optWithServerLog returns a newClientServerTest option that installs lg
// as the server's ErrorLog.
func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
	install := func(srv *httptest.Server) {
		srv.Config.ErrorLog = lg
	}
	return install
}
138
139
140
141
142
143
144
145
146
147
148
149
150 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
151 if mode == http2Mode {
152 CondSkipHTTP2(t)
153 }
154 cst := &clientServerTest{
155 t: t,
156 h2: mode == http2Mode,
157 h: h,
158 }
159 cst.ts = httptest.NewUnstartedServer(h)
160
161 var transportFuncs []func(*Transport)
162 for _, opt := range opts {
163 switch opt := opt.(type) {
164 case func(*Transport):
165 transportFuncs = append(transportFuncs, opt)
166 case func(*httptest.Server):
167 opt(cst.ts)
168 default:
169 t.Fatalf("unhandled option type %T", opt)
170 }
171 }
172
173 if cst.ts.Config.ErrorLog == nil {
174 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
175 }
176
177 switch mode {
178 case http1Mode:
179 cst.ts.Start()
180 case https1Mode:
181 cst.ts.StartTLS()
182 case http2Mode:
183 ExportHttp2ConfigureServer(cst.ts.Config, nil)
184 cst.ts.TLS = cst.ts.Config.TLSConfig
185 cst.ts.StartTLS()
186 default:
187 t.Fatalf("unknown test mode %v", mode)
188 }
189 cst.c = cst.ts.Client()
190 cst.tr = cst.c.Transport.(*Transport)
191 if mode == http2Mode {
192 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
193 t.Fatal(err)
194 }
195 }
196 for _, f := range transportFuncs {
197 f(cst.tr)
198 }
199 t.Cleanup(func() {
200 cst.close()
201 })
202 return cst
203 }
204
// testLogWriter is an io.Writer that forwards everything written to it
// to a test's log.
type testLogWriter struct {
	t testing.TB
}

// Write logs b, with surrounding whitespace trimmed, through t.Logf with
// a "server log: " prefix. It always reports the full length of b as
// written and never returns an error.
func (w testLogWriter) Write(b []byte) (int, error) {
	line := strings.TrimSpace(string(b))
	w.t.Logf("server log: %v", line)
	return len(b), nil
}
213
214
215 func TestNewClientServerTest(t *testing.T) {
216 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
217 }
218 func testNewClientServerTest(t *testing.T, mode testMode) {
219 var got struct {
220 sync.Mutex
221 proto string
222 hasTLS bool
223 }
224 h := HandlerFunc(func(w ResponseWriter, r *Request) {
225 got.Lock()
226 defer got.Unlock()
227 got.proto = r.Proto
228 got.hasTLS = r.TLS != nil
229 })
230 cst := newClientServerTest(t, mode, h)
231 if _, err := cst.c.Head(cst.ts.URL); err != nil {
232 t.Fatal(err)
233 }
234 var wantProto string
235 var wantTLS bool
236 switch mode {
237 case http1Mode:
238 wantProto = "HTTP/1.1"
239 wantTLS = false
240 case https1Mode:
241 wantProto = "HTTP/1.1"
242 wantTLS = true
243 case http2Mode:
244 wantProto = "HTTP/2.0"
245 wantTLS = true
246 }
247 if got.proto != wantProto {
248 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
249 }
250 if got.hasTLS != wantTLS {
251 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
252 }
253 }
254
255 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
256 func testChunkedResponseHeaders(t *testing.T, mode testMode) {
257 log.SetOutput(io.Discard)
258 defer log.SetOutput(os.Stderr)
259 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
260 w.Header().Set("Content-Length", "intentional gibberish")
261 w.(Flusher).Flush()
262 fmt.Fprintf(w, "I am a chunked response.")
263 }))
264
265 res, err := cst.c.Get(cst.ts.URL)
266 if err != nil {
267 t.Fatalf("Get error: %v", err)
268 }
269 defer res.Body.Close()
270 if g, e := res.ContentLength, int64(-1); g != e {
271 t.Errorf("expected ContentLength of %d; got %d", e, g)
272 }
273 wantTE := []string{"chunked"}
274 if mode == http2Mode {
275 wantTE = nil
276 }
277 if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
278 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
279 }
280 if got, haveCL := res.Header["Content-Length"]; haveCL {
281 t.Errorf("Unexpected Content-Length: %q", got)
282 }
283 }
284
// reqFunc issues a request to url using client c and returns the response.
// Method expressions such as (*Client).Get and (*Client).Head have this shape.
type reqFunc func(c *Client, url string) (*Response, error)
286
287
288
// h12Compare describes a test that runs the same handler behind an HTTP/1
// server and an HTTP/2 server and compares the responses the client sees.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // handler served by both servers
	ReqFunc            reqFunc                           // request to issue; nil means (*Client).Get
	CheckResponse      func(proto string, res *Response) // optional; called for each response after normalization and comparison
	EarlyCheckResponse func(proto string, res *Response) // optional; called for each response before normalization
	Opts               []any                             // options forwarded to newClientServerTest
}
296
297 func (tt h12Compare) reqFunc() reqFunc {
298 if tt.ReqFunc == nil {
299 return (*Client).Get
300 }
301 return tt.ReqFunc
302 }
303
304 func (tt h12Compare) run(t *testing.T) {
305 setParallel(t)
306 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
307 defer cst1.close()
308 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
309 defer cst2.close()
310
311 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
312 if err != nil {
313 t.Errorf("HTTP/1 request: %v", err)
314 return
315 }
316 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
317 if err != nil {
318 t.Errorf("HTTP/2 request: %v", err)
319 return
320 }
321
322 if fn := tt.EarlyCheckResponse; fn != nil {
323 fn("HTTP/1.1", res1)
324 fn("HTTP/2.0", res2)
325 }
326
327 tt.normalizeRes(t, res1, "HTTP/1.1")
328 tt.normalizeRes(t, res2, "HTTP/2.0")
329 res1body, res2body := res1.Body, res2.Body
330
331 eres1 := mostlyCopy(res1)
332 eres2 := mostlyCopy(res2)
333 if !reflect.DeepEqual(eres1, eres2) {
334 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
335 cst1.ts.URL, eres1, cst2.ts.URL, eres2)
336 }
337 if !reflect.DeepEqual(res1body, res2body) {
338 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
339 }
340 if fn := tt.CheckResponse; fn != nil {
341 res1.Body, res2.Body = res1body, res2body
342 fn("HTTP/1.1", res1)
343 fn("HTTP/2.0", res2)
344 }
345 }
346
// mostlyCopy returns a shallow copy of r with the Body, TransferEncoding,
// TLS and Request fields cleared. r itself is not modified.
func mostlyCopy(r *Response) *Response {
	dup := new(Response)
	*dup = *r
	dup.Body, dup.Request = nil, nil
	dup.TLS = nil
	dup.TransferEncoding = nil
	return dup
}
355
// slurpResult is a response body that has already been read in full:
// body holds the bytes read and err the error returned by that read.
// The embedded ReadCloser lets it stand in for the original body.
type slurpResult struct {
	io.ReadCloser
	body []byte
	err  error
}

// String renders the slurped body (quoted) and the read error.
func (sr slurpResult) String() string {
	const layout = "body %q; err %v"
	return fmt.Sprintf(layout, sr.body, sr.err)
}
363
364 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
365 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
366 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
367 } else {
368 t.Errorf("got %q response; want %q", res.Proto, wantProto)
369 }
370 slurp, err := io.ReadAll(res.Body)
371
372 res.Body.Close()
373 res.Body = slurpResult{
374 ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
375 body: slurp,
376 err: err,
377 }
378 for i, v := range res.Header["Date"] {
379 res.Header["Date"][i] = strings.Repeat("x", len(v))
380 }
381 if res.Request == nil {
382 t.Errorf("for %s, no request", wantProto)
383 }
384 if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
385 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
386 }
387 }
388
389
390 func TestH12_HeadContentLengthNoBody(t *testing.T) {
391 h12Compare{
392 ReqFunc: (*Client).Head,
393 Handler: func(w ResponseWriter, r *Request) {
394 },
395 }.run(t)
396 }
397
398 func TestH12_HeadContentLengthSmallBody(t *testing.T) {
399 h12Compare{
400 ReqFunc: (*Client).Head,
401 Handler: func(w ResponseWriter, r *Request) {
402 io.WriteString(w, "small")
403 },
404 }.run(t)
405 }
406
407 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
408 h12Compare{
409 ReqFunc: (*Client).Head,
410 Handler: func(w ResponseWriter, r *Request) {
411 chunk := strings.Repeat("x", 512<<10)
412 for i := 0; i < 10; i++ {
413 io.WriteString(w, chunk)
414 }
415 },
416 }.run(t)
417 }
418
419 func TestH12_200NoBody(t *testing.T) {
420 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
421 }
422
423 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
424 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
425 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
426
427 func testH12_noBody(t *testing.T, status int) {
428 h12Compare{Handler: func(w ResponseWriter, r *Request) {
429 w.WriteHeader(status)
430 }}.run(t)
431 }
432
433 func TestH12_SmallBody(t *testing.T) {
434 h12Compare{Handler: func(w ResponseWriter, r *Request) {
435 io.WriteString(w, "small body")
436 }}.run(t)
437 }
438
439 func TestH12_ExplicitContentLength(t *testing.T) {
440 h12Compare{Handler: func(w ResponseWriter, r *Request) {
441 w.Header().Set("Content-Length", "3")
442 io.WriteString(w, "foo")
443 }}.run(t)
444 }
445
446 func TestH12_FlushBeforeBody(t *testing.T) {
447 h12Compare{Handler: func(w ResponseWriter, r *Request) {
448 w.(Flusher).Flush()
449 io.WriteString(w, "foo")
450 }}.run(t)
451 }
452
453 func TestH12_FlushMidBody(t *testing.T) {
454 h12Compare{Handler: func(w ResponseWriter, r *Request) {
455 io.WriteString(w, "foo")
456 w.(Flusher).Flush()
457 io.WriteString(w, "bar")
458 }}.run(t)
459 }
460
461 func TestH12_Head_ExplicitLen(t *testing.T) {
462 h12Compare{
463 ReqFunc: (*Client).Head,
464 Handler: func(w ResponseWriter, r *Request) {
465 if r.Method != "HEAD" {
466 t.Errorf("unexpected method %q", r.Method)
467 }
468 w.Header().Set("Content-Length", "1235")
469 },
470 }.run(t)
471 }
472
473 func TestH12_Head_ImplicitLen(t *testing.T) {
474 h12Compare{
475 ReqFunc: (*Client).Head,
476 Handler: func(w ResponseWriter, r *Request) {
477 if r.Method != "HEAD" {
478 t.Errorf("unexpected method %q", r.Method)
479 }
480 io.WriteString(w, "foo")
481 },
482 }.run(t)
483 }
484
485 func TestH12_HandlerWritesTooLittle(t *testing.T) {
486 h12Compare{
487 Handler: func(w ResponseWriter, r *Request) {
488 w.Header().Set("Content-Length", "3")
489 io.WriteString(w, "12")
490 },
491 CheckResponse: func(proto string, res *Response) {
492 sr, ok := res.Body.(slurpResult)
493 if !ok {
494 t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
495 return
496 }
497 if sr.err != io.ErrUnexpectedEOF {
498 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
499 }
500 if string(sr.body) != "12" {
501 t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
502 }
503 },
504 }.run(t)
505 }
506
507
508
509
510
511
512
513 func TestH12_HandlerWritesTooMuch(t *testing.T) {
514 h12Compare{
515 Handler: func(w ResponseWriter, r *Request) {
516 w.Header().Set("Content-Length", "3")
517 w.(Flusher).Flush()
518 io.WriteString(w, "123")
519 w.(Flusher).Flush()
520 n, err := io.WriteString(w, "x")
521 if n > 0 || err == nil {
522 t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
523 }
524 },
525 }.run(t)
526 }
527
528
529
530 func TestH12_AutoGzip(t *testing.T) {
531 h12Compare{
532 Handler: func(w ResponseWriter, r *Request) {
533 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
534 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
535 }
536 w.Header().Set("Content-Encoding", "gzip")
537 gz := gzip.NewWriter(w)
538 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
539 gz.Close()
540 },
541 }.run(t)
542 }
543
544 func TestH12_AutoGzip_Disabled(t *testing.T) {
545 h12Compare{
546 Opts: []any{
547 func(tr *Transport) { tr.DisableCompression = true },
548 },
549 Handler: func(w ResponseWriter, r *Request) {
550 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
551 if ae := r.Header.Get("Accept-Encoding"); ae != "" {
552 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
553 }
554 },
555 }.run(t)
556 }
557
558
559
560
561 func Test304Responses(t *testing.T) { run(t, test304Responses) }
562 func test304Responses(t *testing.T, mode testMode) {
563 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
564 w.WriteHeader(StatusNotModified)
565 _, err := w.Write([]byte("illegal body"))
566 if err != ErrBodyNotAllowed {
567 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
568 }
569 }))
570 defer cst.close()
571 res, err := cst.c.Get(cst.ts.URL)
572 if err != nil {
573 t.Fatal(err)
574 }
575 if len(res.TransferEncoding) > 0 {
576 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
577 }
578 body, err := io.ReadAll(res.Body)
579 if err != nil {
580 t.Error(err)
581 }
582 if len(body) > 0 {
583 t.Errorf("got unexpected body %q", string(body))
584 }
585 }
586
587 func TestH12_ServerEmptyContentLength(t *testing.T) {
588 h12Compare{
589 Handler: func(w ResponseWriter, r *Request) {
590 w.Header()["Content-Type"] = []string{""}
591 io.WriteString(w, "<html><body>hi</body></html>")
592 },
593 }.run(t)
594 }
595
596 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
597 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
598 }
599
600 func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
601 h12requestContentLength(t, func() io.Reader { return nil }, 0)
602 }
603
604 func TestH12_RequestContentLength_Unknown(t *testing.T) {
605 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
606 }
607
608 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
609 h12Compare{
610 Handler: func(w ResponseWriter, r *Request) {
611 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
612 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
613 },
614 ReqFunc: func(c *Client, url string) (*Response, error) {
615 return c.Post(url, "text/plain", bodyfn())
616 },
617 CheckResponse: func(proto string, res *Response) {
618 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
619 t.Errorf("Proto %q got length %q; want %q", proto, got, want)
620 }
621 },
622 }.run(t)
623 }
624
625
626
627 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
628 func testCancelRequestMidBody(t *testing.T, mode testMode) {
629 unblock := make(chan bool)
630 didFlush := make(chan bool, 1)
631 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
632 io.WriteString(w, "Hello")
633 w.(Flusher).Flush()
634 didFlush <- true
635 <-unblock
636 io.WriteString(w, ", world.")
637 }))
638 defer close(unblock)
639
640 req, _ := NewRequest("GET", cst.ts.URL, nil)
641 cancel := make(chan struct{})
642 req.Cancel = cancel
643
644 res, err := cst.c.Do(req)
645 if err != nil {
646 t.Fatal(err)
647 }
648 defer res.Body.Close()
649 <-didFlush
650
651
652
653 firstRead := make([]byte, 10)
654 n, err := res.Body.Read(firstRead)
655 if err != nil {
656 t.Fatal(err)
657 }
658 firstRead = firstRead[:n]
659
660 close(cancel)
661
662 rest, err := io.ReadAll(res.Body)
663 all := string(firstRead) + string(rest)
664 if all != "Hello" {
665 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
666 }
667 if err != ExportErrRequestCanceled {
668 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
669 }
670 }
671
672
673 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
674 func testTrailersClientToServer(t *testing.T, mode testMode) {
675 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
676 var decl []string
677 for k := range r.Trailer {
678 decl = append(decl, k)
679 }
680 sort.Strings(decl)
681
682 slurp, err := io.ReadAll(r.Body)
683 if err != nil {
684 t.Errorf("Server reading request body: %v", err)
685 }
686 if string(slurp) != "foo" {
687 t.Errorf("Server read request body %q; want foo", slurp)
688 }
689 if r.Trailer == nil {
690 io.WriteString(w, "nil Trailer")
691 } else {
692 fmt.Fprintf(w, "decl: %v, vals: %s, %s",
693 decl,
694 r.Trailer.Get("Client-Trailer-A"),
695 r.Trailer.Get("Client-Trailer-B"))
696 }
697 }))
698
699 var req *Request
700 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
701 eofReaderFunc(func() {
702 req.Trailer["Client-Trailer-A"] = []string{"valuea"}
703 }),
704 strings.NewReader("foo"),
705 eofReaderFunc(func() {
706 req.Trailer["Client-Trailer-B"] = []string{"valueb"}
707 }),
708 ))
709 req.Trailer = Header{
710 "Client-Trailer-A": nil,
711 "Client-Trailer-B": nil,
712 }
713 req.ContentLength = -1
714 res, err := cst.c.Do(req)
715 if err != nil {
716 t.Fatal(err)
717 }
718 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
719 t.Error(err)
720 }
721 }
722
723
724 func TestTrailersServerToClient(t *testing.T) {
725 run(t, func(t *testing.T, mode testMode) {
726 testTrailersServerToClient(t, mode, false)
727 })
728 }
729 func TestTrailersServerToClientFlush(t *testing.T) {
730 run(t, func(t *testing.T, mode testMode) {
731 testTrailersServerToClient(t, mode, true)
732 })
733 }
734
735 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
736 const body = "Some body"
737 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
738 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
739 w.Header().Add("Trailer", "Server-Trailer-C")
740
741 io.WriteString(w, body)
742 if flush {
743 w.(Flusher).Flush()
744 }
745
746
747
748
749
750 w.Header().Set("Server-Trailer-A", "valuea")
751 w.Header().Set("Server-Trailer-C", "valuec")
752 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
753 }))
754
755 res, err := cst.c.Get(cst.ts.URL)
756 if err != nil {
757 t.Fatal(err)
758 }
759
760 wantHeader := Header{
761 "Content-Type": {"text/plain; charset=utf-8"},
762 }
763 wantLen := -1
764 if mode == http2Mode && !flush {
765
766
767
768
769
770 wantLen = len(body)
771 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
772 }
773 if res.ContentLength != int64(wantLen) {
774 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
775 }
776
777 delete(res.Header, "Date")
778 if !reflect.DeepEqual(res.Header, wantHeader) {
779 t.Errorf("Header = %v; want %v", res.Header, wantHeader)
780 }
781
782 if got, want := res.Trailer, (Header{
783 "Server-Trailer-A": nil,
784 "Server-Trailer-B": nil,
785 "Server-Trailer-C": nil,
786 }); !reflect.DeepEqual(got, want) {
787 t.Errorf("Trailer before body read = %v; want %v", got, want)
788 }
789
790 if err := wantBody(res, nil, body); err != nil {
791 t.Fatal(err)
792 }
793
794 if got, want := res.Trailer, (Header{
795 "Server-Trailer-A": {"valuea"},
796 "Server-Trailer-B": nil,
797 "Server-Trailer-C": {"valuec"},
798 }); !reflect.DeepEqual(got, want) {
799 t.Errorf("Trailer after body read = %v; want %v", got, want)
800 }
801 }
802
803
804 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
805 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
806 const body = "Some body"
807 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
808 io.WriteString(w, body)
809 }))
810 res, err := cst.c.Get(cst.ts.URL)
811 if err != nil {
812 t.Fatal(err)
813 }
814 res.Body.Close()
815 data, err := io.ReadAll(res.Body)
816 if len(data) != 0 || err == nil {
817 t.Fatalf("ReadAll returned %q, %v; want error", data, err)
818 }
819 }
820
821 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
822 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
823 const reqBody = "some request body"
824 const resBody = "some response body"
825 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
826 var wg sync.WaitGroup
827 wg.Add(2)
828 didRead := make(chan bool, 1)
829
830 go func() {
831 defer wg.Done()
832 data, err := io.ReadAll(r.Body)
833 if string(data) != reqBody {
834 t.Errorf("Handler read %q; want %q", data, reqBody)
835 }
836 if err != nil {
837 t.Errorf("Handler Read: %v", err)
838 }
839 didRead <- true
840 }()
841
842 go func() {
843 defer wg.Done()
844 if mode != http2Mode {
845
846
847
848
849 <-didRead
850 }
851 io.WriteString(w, resBody)
852 }()
853 wg.Wait()
854 }))
855 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
856 req.Header.Add("Expect", "100-continue")
857 res, err := cst.c.Do(req)
858 if err != nil {
859 t.Fatal(err)
860 }
861 data, err := io.ReadAll(res.Body)
862 defer res.Body.Close()
863 if err != nil {
864 t.Fatal(err)
865 }
866 if string(data) != resBody {
867 t.Errorf("read %q; want %q", data, resBody)
868 }
869 }
870
871 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
872 func testConnectRequest(t *testing.T, mode testMode) {
873 gotc := make(chan *Request, 1)
874 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
875 gotc <- r
876 }))
877
878 u, err := url.Parse(cst.ts.URL)
879 if err != nil {
880 t.Fatal(err)
881 }
882
883 tests := []struct {
884 req *Request
885 want string
886 }{
887 {
888 req: &Request{
889 Method: "CONNECT",
890 Header: Header{},
891 URL: u,
892 },
893 want: u.Host,
894 },
895 {
896 req: &Request{
897 Method: "CONNECT",
898 Header: Header{},
899 URL: u,
900 Host: "example.com:123",
901 },
902 want: "example.com:123",
903 },
904 }
905
906 for i, tt := range tests {
907 res, err := cst.c.Do(tt.req)
908 if err != nil {
909 t.Errorf("%d. RoundTrip = %v", i, err)
910 continue
911 }
912 res.Body.Close()
913 req := <-gotc
914 if req.Method != "CONNECT" {
915 t.Errorf("method = %q; want CONNECT", req.Method)
916 }
917 if req.Host != tt.want {
918 t.Errorf("Host = %q; want %q", req.Host, tt.want)
919 }
920 if req.URL.Host != tt.want {
921 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
922 }
923 }
924 }
925
926 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
927 func testTransportUserAgent(t *testing.T, mode testMode) {
928 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
929 fmt.Fprintf(w, "%q", r.Header["User-Agent"])
930 }))
931
932 either := func(a, b string) string {
933 if mode == http2Mode {
934 return b
935 }
936 return a
937 }
938
939 tests := []struct {
940 setup func(*Request)
941 want string
942 }{
943 {
944 func(r *Request) {},
945 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
946 },
947 {
948 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
949 `["foo/1.2.3"]`,
950 },
951 {
952 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
953 `["single"]`,
954 },
955 {
956 func(r *Request) { r.Header.Set("User-Agent", "") },
957 `[]`,
958 },
959 {
960 func(r *Request) { r.Header["User-Agent"] = nil },
961 `[]`,
962 },
963 }
964 for i, tt := range tests {
965 req, _ := NewRequest("GET", cst.ts.URL, nil)
966 tt.setup(req)
967 res, err := cst.c.Do(req)
968 if err != nil {
969 t.Errorf("%d. RoundTrip = %v", i, err)
970 continue
971 }
972 slurp, err := io.ReadAll(res.Body)
973 res.Body.Close()
974 if err != nil {
975 t.Errorf("%d. read body = %v", i, err)
976 continue
977 }
978 if string(slurp) != tt.want {
979 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
980 }
981 }
982 }
983
984 func TestStarRequestMethod(t *testing.T) {
985 for _, method := range []string{"FOO", "OPTIONS"} {
986 t.Run(method, func(t *testing.T) {
987 run(t, func(t *testing.T, mode testMode) {
988 testStarRequest(t, method, mode)
989 })
990 })
991 }
992 }
993 func testStarRequest(t *testing.T, method string, mode testMode) {
994 gotc := make(chan *Request, 1)
995 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
996 w.Header().Set("foo", "bar")
997 gotc <- r
998 w.(Flusher).Flush()
999 }))
1000
1001 u, err := url.Parse(cst.ts.URL)
1002 if err != nil {
1003 t.Fatal(err)
1004 }
1005 u.Path = "*"
1006
1007 req := &Request{
1008 Method: method,
1009 Header: Header{},
1010 URL: u,
1011 }
1012
1013 res, err := cst.c.Do(req)
1014 if err != nil {
1015 t.Fatalf("RoundTrip = %v", err)
1016 }
1017 res.Body.Close()
1018
1019 wantFoo := "bar"
1020 wantLen := int64(-1)
1021 if method == "OPTIONS" {
1022 wantFoo = ""
1023 wantLen = 0
1024 }
1025 if res.StatusCode != 200 {
1026 t.Errorf("status code = %v; want %d", res.Status, 200)
1027 }
1028 if res.ContentLength != wantLen {
1029 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
1030 }
1031 if got := res.Header.Get("foo"); got != wantFoo {
1032 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
1033 }
1034 select {
1035 case req = <-gotc:
1036 default:
1037 req = nil
1038 }
1039 if req == nil {
1040 if method != "OPTIONS" {
1041 t.Fatalf("handler never got request")
1042 }
1043 return
1044 }
1045 if req.Method != method {
1046 t.Errorf("method = %q; want %q", req.Method, method)
1047 }
1048 if req.URL.Path != "*" {
1049 t.Errorf("URL.Path = %q; want *", req.URL.Path)
1050 }
1051 if req.RequestURI != "*" {
1052 t.Errorf("RequestURI = %q; want *", req.RequestURI)
1053 }
1054 }
1055
1056
1057 func TestTransportDiscardsUnneededConns(t *testing.T) {
1058 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1059 }
1060 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1061 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1062 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1063 }))
1064 defer cst.close()
1065
1066 var numOpen, numClose int32
1067
1068 tlsConfig := &tls.Config{InsecureSkipVerify: true}
1069 tr := &Transport{
1070 TLSClientConfig: tlsConfig,
1071 DialTLS: func(_, addr string) (net.Conn, error) {
1072 time.Sleep(10 * time.Millisecond)
1073 rc, err := net.Dial("tcp", addr)
1074 if err != nil {
1075 return nil, err
1076 }
1077 atomic.AddInt32(&numOpen, 1)
1078 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1079 return tls.Client(c, tlsConfig), nil
1080 },
1081 }
1082 if err := ExportHttp2ConfigureTransport(tr); err != nil {
1083 t.Fatal(err)
1084 }
1085 defer tr.CloseIdleConnections()
1086
1087 c := &Client{Transport: tr}
1088
1089 const N = 10
1090 gotBody := make(chan string, N)
1091 var wg sync.WaitGroup
1092 for i := 0; i < N; i++ {
1093 wg.Add(1)
1094 go func() {
1095 defer wg.Done()
1096 resp, err := c.Get(cst.ts.URL)
1097 if err != nil {
1098
1099
1100 time.Sleep(10 * time.Millisecond)
1101 resp, err = c.Get(cst.ts.URL)
1102 if err != nil {
1103 t.Errorf("Get: %v", err)
1104 return
1105 }
1106 }
1107 defer resp.Body.Close()
1108 slurp, err := io.ReadAll(resp.Body)
1109 if err != nil {
1110 t.Error(err)
1111 }
1112 gotBody <- string(slurp)
1113 }()
1114 }
1115 wg.Wait()
1116 close(gotBody)
1117
1118 var last string
1119 for got := range gotBody {
1120 if last == "" {
1121 last = got
1122 continue
1123 }
1124 if got != last {
1125 t.Errorf("Response body changed: %q -> %q", last, got)
1126 }
1127 }
1128
1129 var open, close int32
1130 for i := 0; i < 150; i++ {
1131 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1132 if open < 1 {
1133 t.Fatalf("open = %d; want at least", open)
1134 }
1135 if close == open-1 {
1136
1137 return
1138 }
1139 time.Sleep(10 * time.Millisecond)
1140 }
1141 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1142 }
1143
1144
1145 func TestTransportGCRequest(t *testing.T) {
1146 run(t, func(t *testing.T, mode testMode) {
1147 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
1148 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
1149 })
1150 }
1151 func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
1152 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1153 io.ReadAll(r.Body)
1154 if body {
1155 io.WriteString(w, "Hello.")
1156 }
1157 }))
1158
1159 didGC := make(chan struct{})
1160 (func() {
1161 body := strings.NewReader("some body")
1162 req, _ := NewRequest("POST", cst.ts.URL, body)
1163 runtime.SetFinalizer(req, func(*Request) { close(didGC) })
1164 res, err := cst.c.Do(req)
1165 if err != nil {
1166 t.Fatal(err)
1167 }
1168 if _, err := io.ReadAll(res.Body); err != nil {
1169 t.Fatal(err)
1170 }
1171 if err := res.Body.Close(); err != nil {
1172 t.Fatal(err)
1173 }
1174 })()
1175 timeout := time.NewTimer(5 * time.Second)
1176 defer timeout.Stop()
1177 for {
1178 select {
1179 case <-didGC:
1180 return
1181 case <-time.After(100 * time.Millisecond):
1182 runtime.GC()
1183 case <-timeout.C:
1184 t.Fatal("never saw GC of request")
1185 }
1186 }
1187 }
1188
1189 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1190 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1191 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1192 fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1193 }), optQuietLog)
1194 cst.tr.DisableKeepAlives = true
1195
1196 tests := []struct {
1197 key, val string
1198 ok bool
1199 }{
1200 {"Foo", "capital-key", true},
1201 {"Foo", "foo\x00bar", false},
1202 {"Foo", "two\nlines", false},
1203 {"bogus\nkey", "v", false},
1204 {"A space", "v", false},
1205 {"имя", "v", false},
1206 {"name", "валю", true},
1207 {"", "v", false},
1208 {"k", "", true},
1209 }
1210 for _, tt := range tests {
1211 dialedc := make(chan bool, 1)
1212 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1213 dialedc <- true
1214 return net.Dial(netw, addr)
1215 }
1216 req, _ := NewRequest("GET", cst.ts.URL, nil)
1217 req.Header[tt.key] = []string{tt.val}
1218 res, err := cst.c.Do(req)
1219 var body []byte
1220 if err == nil {
1221 body, _ = io.ReadAll(res.Body)
1222 res.Body.Close()
1223 }
1224 var dialed bool
1225 select {
1226 case <-dialedc:
1227 dialed = true
1228 default:
1229 }
1230
1231 if !tt.ok && dialed {
1232 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1233 } else if (err == nil) != tt.ok {
1234 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1235 }
1236 }
1237 }
1238
1239 func TestInterruptWithPanic(t *testing.T) {
1240 run(t, func(t *testing.T, mode testMode) {
1241 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1242 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1243 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1244 }, testNotParallel)
1245 }
1246 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
1247 const msg = "hello"
1248
1249 testDone := make(chan struct{})
1250 defer close(testDone)
1251
1252 var errorLog lockedBytesBuffer
1253 gotHeaders := make(chan bool, 1)
1254 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1255 io.WriteString(w, msg)
1256 w.(Flusher).Flush()
1257
1258 select {
1259 case <-gotHeaders:
1260 case <-testDone:
1261 }
1262 panic(panicValue)
1263 }), func(ts *httptest.Server) {
1264 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1265 })
1266 res, err := cst.c.Get(cst.ts.URL)
1267 if err != nil {
1268 t.Fatal(err)
1269 }
1270 gotHeaders <- true
1271 defer res.Body.Close()
1272 slurp, err := io.ReadAll(res.Body)
1273 if string(slurp) != msg {
1274 t.Errorf("client read %q; want %q", slurp, msg)
1275 }
1276 if err == nil {
1277 t.Errorf("client read all successfully; want some error")
1278 }
1279 logOutput := func() string {
1280 errorLog.Lock()
1281 defer errorLog.Unlock()
1282 return errorLog.String()
1283 }
1284 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1285
1286 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
1287 gotLog := logOutput()
1288 if !wantStackLogged {
1289 if gotLog == "" {
1290 return true
1291 }
1292 t.Fatalf("want no log output; got: %s", gotLog)
1293 }
1294 if gotLog == "" {
1295 if d > 0 {
1296 t.Logf("wanted a stack trace logged; got nothing after %v", d)
1297 }
1298 return false
1299 }
1300 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1301 if d > 0 {
1302 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
1303 }
1304 return false
1305 }
1306 return true
1307 })
1308 }
1309
// lockedBytesBuffer is a bytes.Buffer paired with a mutex. Only Write takes
// the lock on its own; callers using any other Buffer method (for example
// String) must hold the embedded Mutex themselves.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends data to the buffer while holding the mutex.
func (lb *lockedBytesBuffer) Write(data []byte) (n int, err error) {
	lb.Mutex.Lock()
	defer lb.Mutex.Unlock()
	n, err = lb.Buffer.Write(data)
	return n, err
}
1320
1321
1322 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1323 h12Compare{
1324 Handler: func(w ResponseWriter, r *Request) {
1325 h := w.Header()
1326 h.Set("Content-Encoding", "gzip")
1327 h.Set("Content-Length", "23")
1328 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1329 },
1330 EarlyCheckResponse: func(proto string, res *Response) {
1331 if !res.Uncompressed {
1332 t.Errorf("%s: expected Uncompressed to be set", proto)
1333 }
1334 dump, err := httputil.DumpResponse(res, true)
1335 if err != nil {
1336 t.Errorf("%s: DumpResponse: %v", proto, err)
1337 return
1338 }
1339 if strings.Contains(string(dump), "Connection: close") {
1340 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1341 }
1342 if !strings.Contains(string(dump), "FOO") {
1343 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1344 }
1345 },
1346 }.run(t)
1347 }
1348
1349
1350 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1351 func testCloseIdleConnections(t *testing.T, mode testMode) {
1352 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1353 w.Header().Set("X-Addr", r.RemoteAddr)
1354 }))
1355 get := func() string {
1356 res, err := cst.c.Get(cst.ts.URL)
1357 if err != nil {
1358 t.Fatal(err)
1359 }
1360 res.Body.Close()
1361 v := res.Header.Get("X-Addr")
1362 if v == "" {
1363 t.Fatal("didn't get X-Addr")
1364 }
1365 return v
1366 }
1367 a1 := get()
1368 cst.tr.CloseIdleConnections()
1369 a2 := get()
1370 if a1 == a2 {
1371 t.Errorf("didn't close connection")
1372 }
1373 }
1374
// noteCloseConn wraps a net.Conn and invokes closeFunc each time Close is
// called, before closing the underlying connection.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close runs the notification hook and then closes the wrapped connection,
// returning the wrapped connection's error.
func (c noteCloseConn) Close() error {
	c.closeFunc()
	err := c.Conn.Close()
	return err
}
1384
// testErrorReader is an io.Reader that must never be read from: every Read
// call marks the associated test as failed.
type testErrorReader struct{ t *testing.T }

// Read records a test error and reports end of input without producing data.
func (er testErrorReader) Read([]byte) (int, error) {
	er.t.Error("unexpected Read call")
	return 0, io.EOF
}
1391
1392 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1393 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1394 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1395 w.WriteHeader(StatusUnauthorized)
1396 }))
1397
1398
1399 cst.tr.ExpectContinueTimeout = 10 * time.Second
1400
1401 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1402 if err != nil {
1403 t.Fatal(err)
1404 }
1405 req.ContentLength = 0
1406 req.Header.Set("Expect", "100-continue")
1407 res, err := cst.tr.RoundTrip(req)
1408 if err != nil {
1409 t.Fatal(err)
1410 }
1411 defer res.Body.Close()
1412 if res.StatusCode != StatusUnauthorized {
1413 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1414 }
1415 }
1416
1417 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1418 func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1419 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1420 w.Header().Set("Foo", "Bar")
1421 w.Header().Set("Trailer:Foo", "Baz")
1422 w.(Flusher).Flush()
1423 w.Header().Add("Trailer:Foo", "Baz2")
1424 w.Header().Set("Trailer:Bar", "Quux")
1425 }))
1426 res, err := cst.c.Get(cst.ts.URL)
1427 if err != nil {
1428 t.Fatal(err)
1429 }
1430 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1431 t.Fatal(err)
1432 }
1433 res.Body.Close()
1434 delete(res.Header, "Date")
1435 delete(res.Header, "Content-Type")
1436
1437 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1438 t.Errorf("Header = %#v; want %#v", res.Header, want)
1439 }
1440 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1441 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1442 }
1443 }
1444
1445 func TestBadResponseAfterReadingBody(t *testing.T) {
1446 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1447 }
1448 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1449 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1450 _, err := io.Copy(io.Discard, r.Body)
1451 if err != nil {
1452 t.Fatal(err)
1453 }
1454 c, _, err := w.(Hijacker).Hijack()
1455 if err != nil {
1456 t.Fatal(err)
1457 }
1458 defer c.Close()
1459 fmt.Fprintln(c, "some bogus crap")
1460 }))
1461
1462 closes := 0
1463 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1464 if err == nil {
1465 res.Body.Close()
1466 t.Fatal("expected an error to be returned from Post")
1467 }
1468 if closes != 1 {
1469 t.Errorf("closes = %d; want 1", closes)
1470 }
1471 }
1472
1473 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1474 func testWriteHeader0(t *testing.T, mode testMode) {
1475 gotpanic := make(chan bool, 1)
1476 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1477 defer close(gotpanic)
1478 defer func() {
1479 if e := recover(); e != nil {
1480 got := fmt.Sprintf("%T, %v", e, e)
1481 want := "string, invalid WriteHeader code 0"
1482 if got != want {
1483 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1484 }
1485 gotpanic <- true
1486
1487
1488
1489
1490 w.WriteHeader(503)
1491 }
1492 }()
1493 w.WriteHeader(0)
1494 }))
1495 res, err := cst.c.Get(cst.ts.URL)
1496 if err != nil {
1497 t.Fatal(err)
1498 }
1499 if res.StatusCode != 503 {
1500 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1501 }
1502 if !<-gotpanic {
1503 t.Error("expected panic in handler")
1504 }
1505 }
1506
1507
1508
1509 func TestWriteHeaderNoCodeCheck(t *testing.T) {
1510 run(t, func(t *testing.T, mode testMode) {
1511 testWriteHeaderAfterWrite(t, mode, false)
1512 })
1513 }
1514 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1515 testWriteHeaderAfterWrite(t, http1Mode, true)
1516 }
1517 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1518 var errorLog lockedBytesBuffer
1519 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1520 if hijack {
1521 conn, _, _ := w.(Hijacker).Hijack()
1522 defer conn.Close()
1523 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1524 w.WriteHeader(0)
1525 conn.Write([]byte("bar"))
1526 return
1527 }
1528 io.WriteString(w, "foo")
1529 w.(Flusher).Flush()
1530 w.WriteHeader(0)
1531 io.WriteString(w, "bar")
1532 }), func(ts *httptest.Server) {
1533 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1534 })
1535 res, err := cst.c.Get(cst.ts.URL)
1536 if err != nil {
1537 t.Fatal(err)
1538 }
1539 defer res.Body.Close()
1540 body, err := io.ReadAll(res.Body)
1541 if err != nil {
1542 t.Fatal(err)
1543 }
1544 if got, want := string(body), "foobar"; got != want {
1545 t.Errorf("got = %q; want %q", got, want)
1546 }
1547
1548
1549 if mode == http2Mode {
1550
1551
1552 return
1553 }
1554 gotLog := strings.TrimSpace(errorLog.String())
1555 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1556 if hijack {
1557 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1558 }
1559 if !strings.HasPrefix(gotLog, wantLog) {
1560 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1561 }
1562 }
1563
1564 func TestBidiStreamReverseProxy(t *testing.T) {
1565 run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1566 }
1567 func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1568 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1569 if _, err := io.Copy(w, r.Body); err != nil {
1570 log.Printf("bidi backend copy: %v", err)
1571 }
1572 }))
1573
1574 backURL, err := url.Parse(backend.ts.URL)
1575 if err != nil {
1576 t.Fatal(err)
1577 }
1578 rp := httputil.NewSingleHostReverseProxy(backURL)
1579 rp.Transport = backend.tr
1580 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1581 rp.ServeHTTP(w, r)
1582 }))
1583
1584 bodyRes := make(chan any, 1)
1585 pr, pw := io.Pipe()
1586 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1587 const size = 4 << 20
1588 go func() {
1589 h := sha1.New()
1590 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1591 go pw.Close()
1592 if err != nil {
1593 bodyRes <- err
1594 } else {
1595 bodyRes <- h
1596 }
1597 }()
1598 res, err := backend.c.Do(req)
1599 if err != nil {
1600 t.Fatal(err)
1601 }
1602 defer res.Body.Close()
1603 hgot := sha1.New()
1604 n, err := io.Copy(hgot, res.Body)
1605 if err != nil {
1606 t.Fatal(err)
1607 }
1608 if n != size {
1609 t.Fatalf("got %d bytes; want %d", n, size)
1610 }
1611 select {
1612 case v := <-bodyRes:
1613 switch v := v.(type) {
1614 default:
1615 t.Fatalf("body copy: %v", err)
1616 case hash.Hash:
1617 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1618 t.Errorf("written bytes didn't match received bytes")
1619 }
1620 }
1621 case <-time.After(10 * time.Second):
1622 t.Fatal("timeout")
1623 }
1624
1625 }
1626
1627
1628 func TestH12_WebSocketUpgrade(t *testing.T) {
1629 h12Compare{
1630 Handler: func(w ResponseWriter, r *Request) {
1631 h := w.Header()
1632 h.Set("Foo", "bar")
1633 },
1634 ReqFunc: func(c *Client, url string) (*Response, error) {
1635 req, _ := NewRequest("GET", url, nil)
1636 req.Header.Set("Connection", "Upgrade")
1637 req.Header.Set("Upgrade", "WebSocket")
1638 return c.Do(req)
1639 },
1640 EarlyCheckResponse: func(proto string, res *Response) {
1641 if res.Proto != "HTTP/1.1" {
1642 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1643 }
1644 res.Proto = "HTTP/IGNORE"
1645 },
1646 }.run(t)
1647 }
1648
1649 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1650 func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1651 const body = "body"
1652 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1653 gotBody, _ := io.ReadAll(r.Body)
1654 if got, want := string(gotBody), body; got != want {
1655 t.Errorf("got request body = %q; want %q", got, want)
1656 }
1657 w.Header().Set("Transfer-Encoding", "identity")
1658 w.WriteHeader(StatusOK)
1659 w.(Flusher).Flush()
1660 io.WriteString(w, body)
1661 }))
1662 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1663 res, err := cst.c.Do(req)
1664 if err != nil {
1665 t.Fatal(err)
1666 }
1667 defer res.Body.Close()
1668 gotBody, err := io.ReadAll(res.Body)
1669 if err != nil {
1670 t.Fatal(err)
1671 }
1672 if got, want := string(gotBody), body; got != want {
1673 t.Errorf("got response body = %q; want %q", got, want)
1674 }
1675 }
1676
1677 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1678 func testEarlyHintsRequest(t *testing.T, mode testMode) {
1679 var wg sync.WaitGroup
1680 wg.Add(1)
1681 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1682 h := w.Header()
1683
1684 h.Add("Content-Length", "123")
1685 h.Add("Link", "</style.css>; rel=preload; as=style")
1686 h.Add("Link", "</script.js>; rel=preload; as=script")
1687 w.WriteHeader(StatusEarlyHints)
1688
1689 wg.Wait()
1690
1691 h.Add("Link", "</foo.js>; rel=preload; as=script")
1692 w.WriteHeader(StatusEarlyHints)
1693
1694 w.Write([]byte("Hello"))
1695 }))
1696
1697 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1698 t.Helper()
1699
1700 if len(expected) != len(got) {
1701 t.Errorf("got %d expected %d", len(got), len(expected))
1702 }
1703
1704 for i := range expected {
1705 if expected[i] != got[i] {
1706 t.Errorf("got %q expected %q", got[i], expected[i])
1707 }
1708 }
1709 }
1710
1711 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1712 t.Helper()
1713
1714 for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1715 if v, ok := header[h]; ok {
1716 t.Errorf("%s is %q; must not be sent", h, v)
1717 }
1718 }
1719 }
1720
1721 var respCounter uint8
1722 trace := &httptrace.ClientTrace{
1723 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1724 switch respCounter {
1725 case 0:
1726 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1727 checkExcludedHeaders(t, header)
1728
1729 wg.Done()
1730 case 1:
1731 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1732 checkExcludedHeaders(t, header)
1733
1734 default:
1735 t.Error("Unexpected 1xx response")
1736 }
1737
1738 respCounter++
1739
1740 return nil
1741 },
1742 }
1743 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1744
1745 res, err := cst.c.Do(req)
1746 if err != nil {
1747 t.Fatal(err)
1748 }
1749 defer res.Body.Close()
1750
1751 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1752 if cl := res.Header.Get("Content-Length"); cl != "123" {
1753 t.Errorf("Content-Length is %q; want 123", cl)
1754 }
1755
1756 body, _ := io.ReadAll(res.Body)
1757 if string(body) != "Hello" {
1758 t.Errorf("Read body %q; want Hello", body)
1759 }
1760 }
1761
View as plain text