Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "encoding/json"
17 "errors"
18 "fmt"
19 "internal/testenv"
20 "io"
21 "log"
22 "math/rand"
23 "mime/multipart"
24 "net"
25 . "net/http"
26 "net/http/httptest"
27 "net/http/httptrace"
28 "net/http/httputil"
29 "net/http/internal"
30 "net/http/internal/testcert"
31 "net/url"
32 "os"
33 "os/exec"
34 "path/filepath"
35 "reflect"
36 "regexp"
37 "runtime"
38 "strconv"
39 "strings"
40 "sync"
41 "sync/atomic"
42 "syscall"
43 "testing"
44 "time"
45 )
46
// dummyAddr is a net.Addr whose network name and textual form are both the
// underlying string value.
type dummyAddr string

// oneConnListener is a net.Listener that yields its single stored
// connection from the first Accept and io.EOF from every later Accept.
type oneConnListener struct {
	conn net.Conn
}

// Accept hands out the stored connection once, clearing it so that the next
// call reports (nil, io.EOF).
func (l *oneConnListener) Accept() (net.Conn, error) {
	next := l.conn
	if next == nil {
		return nil, io.EOF
	}
	l.conn = nil
	return next, nil
}

// Close does nothing and always succeeds.
func (l *oneConnListener) Close() error { return nil }

// Addr reports the fixed placeholder address "test-address".
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// Network returns the address text unchanged.
func (a dummyAddr) Network() string { return string(a) }

// String returns the address text unchanged.
func (a dummyAddr) String() string { return string(a) }

// noopConn provides the address and deadline methods of net.Conn as inert
// stubs; types that embed it supply Read, Write and Close themselves.
type noopConn struct{}

func (noopConn) LocalAddr() net.Addr                { return dummyAddr("local-addr") }
func (noopConn) RemoteAddr() net.Addr               { return dummyAddr("remote-addr") }
func (noopConn) SetDeadline(t time.Time) error      { return nil }
func (noopConn) SetReadDeadline(t time.Time) error  { return nil }
func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
86
87 type rwTestConn struct {
88 io.Reader
89 io.Writer
90 noopConn
91
92 closeFunc func() error
93 closec chan bool
94 }
95
96 func (c *rwTestConn) Close() error {
97 if c.closeFunc != nil {
98 return c.closeFunc()
99 }
100 select {
101 case c.closec <- true:
102 default:
103 }
104 return nil
105 }
106
107 type testConn struct {
108 readMu sync.Mutex
109 readBuf bytes.Buffer
110 writeBuf bytes.Buffer
111 closec chan bool
112 noopConn
113 }
114
115 func (c *testConn) Read(b []byte) (int, error) {
116 c.readMu.Lock()
117 defer c.readMu.Unlock()
118 return c.readBuf.Read(b)
119 }
120
121 func (c *testConn) Write(b []byte) (int, error) {
122 return c.writeBuf.Write(b)
123 }
124
125 func (c *testConn) Close() error {
126 select {
127 case c.closec <- true:
128 default:
129 }
130 return nil
131 }
132
133
134
// reqBytes converts a "\n"-delimited request description into wire form:
// surrounding whitespace is trimmed, every "\n" becomes "\r\n", and a blank
// line ("\r\n\r\n") is appended to terminate the header section.
func reqBytes(req string) []byte {
	wire := strings.TrimSpace(req)
	wire = strings.ReplaceAll(wire, "\n", "\r\n")
	return []byte(wire + "\r\n\r\n")
}
138
139 type handlerTest struct {
140 logbuf bytes.Buffer
141 handler Handler
142 }
143
144 func newHandlerTest(h Handler) handlerTest {
145 return handlerTest{handler: h}
146 }
147
148 func (ht *handlerTest) rawResponse(req string) string {
149 reqb := reqBytes(req)
150 var output strings.Builder
151 conn := &rwTestConn{
152 Reader: bytes.NewReader(reqb),
153 Writer: &output,
154 closec: make(chan bool, 1),
155 }
156 ln := &oneConnListener{conn: conn}
157 srv := &Server{
158 ErrorLog: log.New(&ht.logbuf, "", 0),
159 Handler: ht.handler,
160 }
161 go srv.Serve(ln)
162 <-conn.closec
163 return output.String()
164 }
165
166 func TestConsumingBodyOnNextConn(t *testing.T) {
167 t.Parallel()
168 defer afterTest(t)
169 conn := new(testConn)
170 for i := 0; i < 2; i++ {
171 conn.readBuf.Write([]byte(
172 "POST / HTTP/1.1\r\n" +
173 "Host: test\r\n" +
174 "Content-Length: 11\r\n" +
175 "\r\n" +
176 "foo=1&bar=1"))
177 }
178
179 reqNum := 0
180 ch := make(chan *Request)
181 servech := make(chan error)
182 listener := &oneConnListener{conn}
183 handler := func(res ResponseWriter, req *Request) {
184 reqNum++
185 ch <- req
186 }
187
188 go func() {
189 servech <- Serve(listener, HandlerFunc(handler))
190 }()
191
192 var req *Request
193 req = <-ch
194 if req == nil {
195 t.Fatal("Got nil first request.")
196 }
197 if req.Method != "POST" {
198 t.Errorf("For request #1's method, got %q; expected %q",
199 req.Method, "POST")
200 }
201
202 req = <-ch
203 if req == nil {
204 t.Fatal("Got nil first request.")
205 }
206 if req.Method != "POST" {
207 t.Errorf("For request #2's method, got %q; expected %q",
208 req.Method, "POST")
209 }
210
211 if serveerr := <-servech; serveerr != io.EOF {
212 t.Errorf("Serve returned %q; expected EOF", serveerr)
213 }
214 }
215
// stringHandler is a Handler that reports its own value in the "Result"
// response header and writes nothing else.
type stringHandler string

// ServeHTTP sets the "Result" header to s.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
221
// handlers lists the patterns TestHostHandlers registers on its ServeMux,
// each served by a stringHandler that echoes msg in the "Result" header.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}

// vtests pairs request URLs with the expected "Result" header (for 200
// responses) or "Location" header (for 301 responses).
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Subtree patterns requested without the trailing slash get a redirect.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
247
248 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
249 func testHostHandlers(t *testing.T, mode testMode) {
250 mux := NewServeMux()
251 for _, h := range handlers {
252 mux.Handle(h.pattern, stringHandler(h.msg))
253 }
254 ts := newClientServerTest(t, mode, mux).ts
255
256 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
257 if err != nil {
258 t.Fatal(err)
259 }
260 defer conn.Close()
261 cc := httputil.NewClientConn(conn, nil)
262 for _, vt := range vtests {
263 var r *Response
264 var req Request
265 if req.URL, err = url.Parse(vt.url); err != nil {
266 t.Errorf("cannot parse url: %v", err)
267 continue
268 }
269 if err := cc.Write(&req); err != nil {
270 t.Errorf("writing request: %v", err)
271 continue
272 }
273 r, err := cc.Read(&req)
274 if err != nil {
275 t.Errorf("reading response: %v", err)
276 continue
277 }
278 switch r.StatusCode {
279 case StatusOK:
280 s := r.Header.Get("Result")
281 if s != vt.expected {
282 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
283 }
284 case StatusMovedPermanently:
285 s := r.Header.Get("Location")
286 if s != vt.expected {
287 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
288 }
289 default:
290 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
291 }
292 }
293 }
294
// serveMuxRegister is the set of patterns (with and without host parts)
// that the ServeMux tests register before looking up requests.
var serveMuxRegister = []struct {
	pattern string
	h       Handler
}{
	{pattern: "/dir/", h: serve(200)},
	{pattern: "/search", h: serve(201)},
	{pattern: "codesearch.google.com/search", h: serve(202)},
	{pattern: "codesearch.google.com/", h: serve(203)},
	{pattern: "example.com/", h: HandlerFunc(checkQueryStringHandler)},
}

// serve returns a handler that only writes the given status code.
func serve(code int) HandlerFunc {
	return HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	})
}

// checkQueryStringHandler replies 200 when the raw query string, prefixed
// with "http://", equals the request URL rebuilt with scheme "http", the
// request's Host, and no query. Otherwise it replies 500. The tests use it
// to confirm that a redirect preserved the query string.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	stripped := *r.URL
	stripped.Scheme = "http"
	stripped.Host = r.Host
	stripped.RawQuery = ""

	status := 500
	if stripped.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
327
// serveMuxTests lists requests to look up in a mux built from
// serveMuxRegister, with the expected status code and matched pattern.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{method: "GET", host: "google.com", path: "/", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir", code: 301, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/dir/", code: 200, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/dir/file", code: 200, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/search", code: 201, pattern: "/search"},
	{method: "GET", host: "google.com", path: "/search/", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/search/foo", code: 404, pattern: ""},
	{method: "GET", host: "codesearch.google.com", path: "/search", code: 202, pattern: "codesearch.google.com/search"},
	{method: "GET", host: "codesearch.google.com", path: "/search/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com", path: "/search/foo", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com", path: "/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com:443", path: "/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "images.google.com", path: "/search", code: 201, pattern: "/search"},
	{method: "GET", host: "images.google.com", path: "/search/", code: 404, pattern: ""},
	{method: "GET", host: "images.google.com", path: "/search/foo", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/../search", code: 301, pattern: "/search"},
	{method: "GET", host: "google.com", path: "/dir/..", code: 301, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir/..", code: 301, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir/./file", code: 301, pattern: "/dir/"},

	// CONNECT requests still get the /dir -> /dir/ redirect, but their
	// paths are not cleaned: "..", "." and leading "/../" stay as sent.
	{method: "CONNECT", host: "google.com", path: "/dir", code: 301, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/../search", code: 404, pattern: ""},
	{method: "CONNECT", host: "google.com", path: "/dir/..", code: 200, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/dir/..", code: 200, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/dir/./file", code: 200, pattern: "/dir/"},
}
363
364 func TestServeMuxHandler(t *testing.T) {
365 setParallel(t)
366 mux := NewServeMux()
367 for _, e := range serveMuxRegister {
368 mux.Handle(e.pattern, e.h)
369 }
370
371 for _, tt := range serveMuxTests {
372 r := &Request{
373 Method: tt.method,
374 Host: tt.host,
375 URL: &url.URL{
376 Path: tt.path,
377 },
378 }
379 h, pattern := mux.Handler(r)
380 rr := httptest.NewRecorder()
381 h.ServeHTTP(rr, r)
382 if pattern != tt.pattern || rr.Code != tt.code {
383 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
384 }
385 }
386 }
387
388
389 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
390 setParallel(t)
391 defer func() {
392 if err := recover(); err == nil {
393 t.Error("expected call to mux.HandleFunc to panic")
394 }
395 }()
396 mux := NewServeMux()
397 mux.HandleFunc("/", nil)
398 }
399
// serveMuxTests2 drives TestServeMuxHandlerRedirects. url is parsed as the
// request URL, code is the final expected status, and redirOk says whether
// a single 301 may be followed on the way there.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
411
412
413
414 func TestServeMuxHandlerRedirects(t *testing.T) {
415 setParallel(t)
416 mux := NewServeMux()
417 for _, e := range serveMuxRegister {
418 mux.Handle(e.pattern, e.h)
419 }
420
421 for _, tt := range serveMuxTests2 {
422 tries := 1
423 turl := tt.url
424 for {
425 u, e := url.Parse(turl)
426 if e != nil {
427 t.Fatal(e)
428 }
429 r := &Request{
430 Method: tt.method,
431 Host: tt.host,
432 URL: u,
433 }
434 h, _ := mux.Handler(r)
435 rr := httptest.NewRecorder()
436 h.ServeHTTP(rr, r)
437 if rr.Code != 301 {
438 if rr.Code != tt.code {
439 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
440 }
441 break
442 }
443 if !tt.redirOk {
444 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
445 break
446 }
447 turl = rr.HeaderMap.Get("Location")
448 tries--
449 }
450 if tries < 0 {
451 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
452 }
453 }
454 }
455
456
457 func TestMuxRedirectLeadingSlashes(t *testing.T) {
458 setParallel(t)
459 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
460 for _, path := range paths {
461 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
462 if err != nil {
463 t.Errorf("%s", err)
464 }
465 mux := NewServeMux()
466 resp := httptest.NewRecorder()
467
468 mux.ServeHTTP(resp, req)
469
470 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
471 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
472 return
473 }
474
475 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
476 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
477 return
478 }
479 }
480 }
481
482
483
484
485
486 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
487 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
488 }
489 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
490 writeBackQuery := func(w ResponseWriter, r *Request) {
491 fmt.Fprintf(w, "%s", r.URL.RawQuery)
492 }
493
494 mux := NewServeMux()
495 mux.HandleFunc("/testOne", writeBackQuery)
496 mux.HandleFunc("/testTwo/", writeBackQuery)
497 mux.HandleFunc("/testThree", writeBackQuery)
498 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
499 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
500 })
501
502 ts := newClientServerTest(t, mode, mux).ts
503
504 tests := [...]struct {
505 path string
506 method string
507 want string
508 statusOk bool
509 }{
510 0: {"/testOne?this=that", "GET", "this=that", true},
511 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
512 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
513 3: {"/testTwo?", "GET", "", true},
514 4: {"/testThree?foo", "GET", "foo", true},
515 5: {"/testThree/?foo", "GET", "foo:bar", true},
516 6: {"/testThree?foo", "CONNECT", "foo", true},
517 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
518
519
520 8: {"/testOne/foo/..?foo", "GET", "foo", true},
521 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
522 }
523
524 for i, tt := range tests {
525 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
526 res, err := ts.Client().Do(req)
527 if err != nil {
528 continue
529 }
530 slurp, _ := io.ReadAll(res.Body)
531 res.Body.Close()
532 if !tt.statusOk {
533 if got, want := res.StatusCode, 404; got != want {
534 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
535 }
536 }
537 if got, want := string(slurp), tt.want; got != want {
538 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
539 }
540 }
541 }
542
543 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
544 setParallel(t)
545
546 mux := NewServeMux()
547 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
548 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
549 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
550 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
551 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
552 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
553
554 tests := []struct {
555 method string
556 url string
557 code int
558 loc string
559 want string
560 }{
561 {"GET", "http://example.com/", 404, "", ""},
562 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
563 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
564 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
565 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
566 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
567 {"CONNECT", "http://example.com/", 404, "", ""},
568 {"CONNECT", "http://example.com:3000/", 404, "", ""},
569 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
570 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
571 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
572 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
573 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
574 }
575
576 for i, tt := range tests {
577 req, _ := NewRequest(tt.method, tt.url, nil)
578 w := httptest.NewRecorder()
579 mux.ServeHTTP(w, req)
580
581 if got, want := w.Code, tt.code; got != want {
582 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
583 }
584
585 if tt.code == 301 {
586 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
587 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
588 }
589 } else {
590 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
591 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
592 }
593 }
594 }
595 }
596
597 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
598 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
599 mux := NewServeMux()
600 newClientServerTest(t, mode, mux)
601 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
602 }
603
604 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
605 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
606 func benchmarkServeMux(b *testing.B, runHandler bool) {
607 type test struct {
608 path string
609 code int
610 req *Request
611 }
612
613
614 var tests []test
615 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
616 for _, e := range endpoints {
617 for i := 200; i < 230; i++ {
618 p := fmt.Sprintf("/%s/%d/", e, i)
619 tests = append(tests, test{
620 path: p,
621 code: i,
622 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
623 })
624 }
625 }
626 mux := NewServeMux()
627 for _, tt := range tests {
628 mux.Handle(tt.path, serve(tt.code))
629 }
630
631 rw := httptest.NewRecorder()
632 b.ReportAllocs()
633 b.ResetTimer()
634 for i := 0; i < b.N; i++ {
635 for _, tt := range tests {
636 *rw = httptest.ResponseRecorder{}
637 h, pattern := mux.Handler(tt.req)
638 if runHandler {
639 h.ServeHTTP(rw, tt.req)
640 if pattern != tt.path || rw.Code != tt.code {
641 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
642 }
643 }
644 }
645 }
646 }
647
648 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
649 func testServerTimeouts(t *testing.T, mode testMode) {
650
651 tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
652 for i, timeout := range tries {
653 err := testServerTimeoutsWithTimeout(t, timeout, mode)
654 if err == nil {
655 return
656 }
657 t.Logf("failed at %v: %v", timeout, err)
658 if i != len(tries)-1 {
659 t.Logf("retrying at %v ...", tries[i+1])
660 }
661 }
662 t.Fatal("all attempts failed")
663 }
664
665 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
666 reqNum := 0
667 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
668 reqNum++
669 fmt.Fprintf(res, "req=%d", reqNum)
670 }), func(ts *httptest.Server) {
671 ts.Config.ReadTimeout = timeout
672 ts.Config.WriteTimeout = timeout
673 }).ts
674
675
676 c := ts.Client()
677 r, err := c.Get(ts.URL)
678 if err != nil {
679 return fmt.Errorf("http Get #1: %v", err)
680 }
681 got, err := io.ReadAll(r.Body)
682 expected := "req=1"
683 if string(got) != expected || err != nil {
684 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
685 string(got), err, expected)
686 }
687
688
689 t1 := time.Now()
690 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
691 if err != nil {
692 return fmt.Errorf("Dial: %v", err)
693 }
694 buf := make([]byte, 1)
695 n, err := conn.Read(buf)
696 conn.Close()
697 latency := time.Since(t1)
698 if n != 0 || err != io.EOF {
699 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
700 }
701 minLatency := timeout / 5 * 4
702 if latency < minLatency {
703 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
704 }
705
706
707
708
709 r, err = c.Get(ts.URL)
710 if err != nil {
711 return fmt.Errorf("http Get #2: %v", err)
712 }
713 got, err = io.ReadAll(r.Body)
714 r.Body.Close()
715 expected = "req=2"
716 if string(got) != expected || err != nil {
717 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
718 }
719
720 if !testing.Short() {
721 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
722 if err != nil {
723 return fmt.Errorf("long Dial: %v", err)
724 }
725 defer conn.Close()
726 go io.Copy(io.Discard, conn)
727 for i := 0; i < 5; i++ {
728 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
729 if err != nil {
730 return fmt.Errorf("on write %d: %v", i, err)
731 }
732 time.Sleep(timeout / 2)
733 }
734 }
735 return nil
736 }
737
738 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
739 func testServerReadTimeout(t *testing.T, mode testMode) {
740 respBody := "response body"
741 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
742 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
743 _, err := io.Copy(io.Discard, req.Body)
744 if !errors.Is(err, os.ErrDeadlineExceeded) {
745 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
746 }
747 res.Write([]byte(respBody))
748 }), func(ts *httptest.Server) {
749 ts.Config.ReadHeaderTimeout = -1
750 ts.Config.ReadTimeout = timeout
751 })
752 pr, pw := io.Pipe()
753 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
754 if err != nil {
755 t.Logf("Get error, retrying: %v", err)
756 cst.close()
757 continue
758 }
759 defer res.Body.Close()
760 got, err := io.ReadAll(res.Body)
761 if string(got) != respBody || err != nil {
762 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
763 }
764 pw.Close()
765 break
766 }
767 }
768
769 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
770 func testServerWriteTimeout(t *testing.T, mode testMode) {
771 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
772 errc := make(chan error, 2)
773 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
774 errc <- nil
775 _, err := io.Copy(res, neverEnding('a'))
776 errc <- err
777 }), func(ts *httptest.Server) {
778 ts.Config.WriteTimeout = timeout
779 })
780 res, err := cst.c.Get(cst.ts.URL)
781 if err != nil {
782
783 t.Logf("Get error, retrying: %v", err)
784 cst.close()
785 continue
786 }
787 defer res.Body.Close()
788 _, err = io.Copy(io.Discard, res.Body)
789 if err == nil {
790 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
791 }
792 select {
793 case <-errc:
794 err = <-errc
795 if !errors.Is(err, os.ErrDeadlineExceeded) {
796 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
797 }
798 return
799 default:
800
801 t.Logf("handler didn't run, retrying")
802 cst.close()
803 }
804 }
805 }
806
807
808 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
809 run(t, testWriteDeadlineExtendedOnNewRequest)
810 }
811 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
812 if testing.Short() {
813 t.Skip("skipping in short mode")
814 }
815 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
816 func(ts *httptest.Server) {
817 ts.Config.WriteTimeout = 250 * time.Millisecond
818 },
819 ).ts
820
821 c := ts.Client()
822
823 for i := 1; i <= 3; i++ {
824 req, err := NewRequest("GET", ts.URL, nil)
825 if err != nil {
826 t.Fatal(err)
827 }
828
829 r, err := c.Do(req)
830 if err != nil {
831 t.Fatalf("http2 Get #%d: %v", i, err)
832 }
833 r.Body.Close()
834 time.Sleep(ts.Config.WriteTimeout / 2)
835 }
836 }
837
838
839
// tryTimeouts calls testFunc with timeouts of 250ms, 500ms and 1s in turn,
// returning at the first call that reports nil. Each failure is logged along
// with the next timeout to try. If every attempt fails, the test is stopped
// with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	timeouts := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i := range timeouts {
		cur := timeouts[i]
		err := testFunc(cur)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", cur, err)
		if next := i + 1; next < len(timeouts) {
			t.Logf("retrying at %v ...", timeouts[next])
		}
	}
	t.Fatal("all attempts failed")
}
854
855
856 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
857 if testing.Short() {
858 t.Skip("skipping in short mode")
859 }
860 setParallel(t)
861 run(t, func(t *testing.T, mode testMode) {
862 tryTimeouts(t, func(timeout time.Duration) error {
863 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
864 })
865 })
866 }
867
868 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
869 reqNum := 0
870 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
871 reqNum++
872 if reqNum == 1 {
873 return
874 }
875 time.Sleep(timeout)
876 }), func(ts *httptest.Server) {
877 ts.Config.WriteTimeout = timeout / 2
878 }).ts
879
880 c := ts.Client()
881
882 req, err := NewRequest("GET", ts.URL, nil)
883 if err != nil {
884 return fmt.Errorf("NewRequest: %v", err)
885 }
886 r, err := c.Do(req)
887 if err != nil {
888 return fmt.Errorf("Get #1: %v", err)
889 }
890 r.Body.Close()
891
892 req, err = NewRequest("GET", ts.URL, nil)
893 if err != nil {
894 return fmt.Errorf("NewRequest: %v", err)
895 }
896 r, err = c.Do(req)
897 if err == nil {
898 r.Body.Close()
899 return fmt.Errorf("Get #2 expected error, got nil")
900 }
901 if mode == http2Mode {
902 expected := "stream ID 3; INTERNAL_ERROR"
903 if !strings.Contains(err.Error(), expected) {
904 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
905 }
906 }
907 return nil
908 }
909
910
911 func TestNoWriteDeadline(t *testing.T) {
912 if testing.Short() {
913 t.Skip("skipping in short mode")
914 }
915 setParallel(t)
916 defer afterTest(t)
917 run(t, func(t *testing.T, mode testMode) {
918 tryTimeouts(t, func(timeout time.Duration) error {
919 return testNoWriteDeadline(t, mode, timeout)
920 })
921 })
922 }
923
924 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
925 reqNum := 0
926 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
927 reqNum++
928 if reqNum == 1 {
929 return
930 }
931 time.Sleep(timeout)
932 })).ts
933
934 c := ts.Client()
935
936 for i := 0; i < 2; i++ {
937 req, err := NewRequest("GET", ts.URL, nil)
938 if err != nil {
939 return fmt.Errorf("NewRequest: %v", err)
940 }
941 r, err := c.Do(req)
942 if err != nil {
943 return fmt.Errorf("Get #%d: %v", i, err)
944 }
945 r.Body.Close()
946 }
947 return nil
948 }
949
950
951
952
953 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
954 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
955 var (
956 mu sync.RWMutex
957 conn net.Conn
958 )
959 var afterTimeoutErrc = make(chan error, 1)
960 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
961 buf := make([]byte, 512<<10)
962 _, err := w.Write(buf)
963 if err != nil {
964 t.Errorf("handler Write error: %v", err)
965 return
966 }
967 mu.RLock()
968 defer mu.RUnlock()
969 if conn == nil {
970 t.Error("no established connection found")
971 return
972 }
973 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
974 _, err = w.Write(buf)
975 afterTimeoutErrc <- err
976 }), func(ts *httptest.Server) {
977 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
978 }).ts
979
980 c := ts.Client()
981
982 err := func() error {
983 res, err := c.Get(ts.URL)
984 if err != nil {
985 return err
986 }
987 _, err = io.Copy(io.Discard, res.Body)
988 res.Body.Close()
989 return err
990 }()
991 if err == nil {
992 t.Errorf("expected an error copying body from Get request")
993 }
994
995 if err := <-afterTimeoutErrc; err == nil {
996 t.Error("expected write error after timeout")
997 }
998 }
999
1000
// trackLastConnListener wraps a net.Listener. Each connection it accepts
// successfully is stored into *last while holding mu, so other goroutines
// can reach the most recently accepted connection.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex // guards *last
	last *net.Conn     // receives every successfully accepted conn
}

// Accept forwards to the wrapped listener and records the connection on
// success. Errors are passed through without touching *last.
func (l trackLastConnListener) Accept() (net.Conn, error) {
	c, err := l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
1017
1018
1019 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1020 func testIdentityResponse(t *testing.T, mode testMode) {
1021 if mode == http2Mode {
1022 t.Skip("https://go.dev/issue/56019")
1023 }
1024
1025 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1026 rw.Header().Set("Content-Length", "3")
1027 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1028 switch {
1029 case req.FormValue("overwrite") == "1":
1030 _, err := rw.Write([]byte("foo TOO LONG"))
1031 if err != ErrContentLength {
1032 t.Errorf("expected ErrContentLength; got %v", err)
1033 }
1034 case req.FormValue("underwrite") == "1":
1035 rw.Header().Set("Content-Length", "500")
1036 rw.Write([]byte("too short"))
1037 default:
1038 rw.Write([]byte("foo"))
1039 }
1040 })
1041
1042 ts := newClientServerTest(t, mode, handler).ts
1043 c := ts.Client()
1044
1045
1046
1047
1048
1049 for _, te := range []string{"", "identity"} {
1050 url := ts.URL + "/?te=" + te
1051 res, err := c.Get(url)
1052 if err != nil {
1053 t.Fatalf("error with Get of %s: %v", url, err)
1054 }
1055 if cl, expected := res.ContentLength, int64(3); cl != expected {
1056 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1057 }
1058 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1059 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1060 }
1061 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1062 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1063 url, expected, tl, res.TransferEncoding)
1064 }
1065 res.Body.Close()
1066 }
1067
1068
1069 url := ts.URL + "/?overwrite=1"
1070 res, err := c.Get(url)
1071 if err != nil {
1072 t.Fatalf("error with Get of %s: %v", url, err)
1073 }
1074 res.Body.Close()
1075
1076 if mode != http1Mode {
1077 return
1078 }
1079
1080
1081
1082 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1083 if err != nil {
1084 t.Fatalf("error dialing: %v", err)
1085 }
1086 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1087 if err != nil {
1088 t.Fatalf("error writing: %v", err)
1089 }
1090
1091
1092 got, _ := io.ReadAll(conn)
1093 expectedSuffix := "\r\n\r\ntoo short"
1094 if !strings.HasSuffix(string(got), expectedSuffix) {
1095 t.Errorf("Expected output to end with %q; got response body %q",
1096 expectedSuffix, string(got))
1097 }
1098 }
1099
1100 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1101 setParallel(t)
1102 s := newClientServerTest(t, http1Mode, h).ts
1103
1104 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1105 if err != nil {
1106 t.Fatal("dial error:", err)
1107 }
1108 defer conn.Close()
1109
1110 _, err = fmt.Fprint(conn, req)
1111 if err != nil {
1112 t.Fatal("print error:", err)
1113 }
1114
1115 r := bufio.NewReader(conn)
1116 res, err := ReadResponse(r, &Request{Method: "GET"})
1117 if err != nil {
1118 t.Fatal("ReadResponse error:", err)
1119 }
1120
1121 _, err = io.ReadAll(r)
1122 if err != nil {
1123 t.Fatal("read error:", err)
1124 }
1125
1126 if !res.Close {
1127 t.Errorf("Response.Close = false; want true")
1128 }
1129 }
1130
1131 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1132 setParallel(t)
1133 ts := newClientServerTest(t, http1Mode, handler).ts
1134 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1135 if err != nil {
1136 t.Fatal(err)
1137 }
1138 defer conn.Close()
1139 br := bufio.NewReader(conn)
1140 for i := 0; i < 2; i++ {
1141 if _, err := io.WriteString(conn, req); err != nil {
1142 t.Fatal(err)
1143 }
1144 res, err := ReadResponse(br, nil)
1145 if err != nil {
1146 t.Fatalf("res %d: %v", i+1, err)
1147 }
1148 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1149 t.Fatalf("res %d body copy: %v", i+1, err)
1150 }
1151 res.Body.Close()
1152 }
1153 }
1154
1155
1156 func TestServeHTTP10Close(t *testing.T) {
1157 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1158 ServeFile(w, r, "testdata/file")
1159 }))
1160 }
1161
1162
1163 func TestClientCanClose(t *testing.T) {
1164 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1165
1166 }))
1167 }
1168
1169
1170
1171 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1172 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1173 w.Header().Set("Connection", "close")
1174 }))
1175 }
1176
1177 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1178 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1179 w.Header().Set("Connection", "close")
1180 }))
1181 }
1182
1183 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1184 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1185
1186
1187 }))
1188 }
1189
// send204 is a handler that replies with 204 No Content and no body.
func send204(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNoContent)
}

// send304 is a handler that replies with 304 Not Modified and no body.
func send304(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNotModified)
}
1192
1193
1194 func TestHTTP10KeepAlive204Response(t *testing.T) {
1195 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1196 }
1197
1198 func TestHTTP11KeepAlive204Response(t *testing.T) {
1199 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1200 }
1201
1202 func TestHTTP10KeepAlive304Response(t *testing.T) {
1203 testTCPConnectionStaysOpen(t,
1204 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1205 HandlerFunc(send304))
1206 }
1207
1208
// TestKeepAliveFinalChunkWithEOF checks that two sequential GETs reuse the same
// connection when each response body is consumed only by a json.Decoder and
// then closed. Reuse is detected by comparing the server-observed RemoteAddr.
func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Flush before any body bytes are written (presumably so that, in
		// HTTP/1.1, the body goes out chunked rather than with a Content-Length).
		w.(Flusher).Flush()
		w.Write([]byte(`{"Addr": "` + r.RemoteAddr + `"}`))
	}))
	type data struct {
		Addr string
	}
	var addrs [2]data
	for i := range addrs {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		dec := json.NewDecoder(res.Body)
		if err := dec.Decode(&addrs[i]); err != nil {
			t.Fatal(err)
		}
		if addrs[i].Addr == "" {
			t.Fatal("no address")
		}
		res.Body.Close()
	}
	if addrs[0] != addrs[1] {
		t.Fatalf("connection not reused")
	}
}
1236
// TestSetsRemoteAddr checks that Request.RemoteAddr reports the client's
// loopback address ("127.0.0.1:port" or "[::1]:port").
func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
func testSetsRemoteAddr(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%s", r.RemoteAddr)
	}))

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	// Release the response body so the client connection is not leaked.
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	ip := string(body)
	if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
		t.Fatalf("Expected local addr; got %q", ip)
	}
}
1256
1257 type blockingRemoteAddrListener struct {
1258 net.Listener
1259 conns chan<- net.Conn
1260 }
1261
1262 func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
1263 c, err := l.Listener.Accept()
1264 if err != nil {
1265 return nil, err
1266 }
1267 brac := &blockingRemoteAddrConn{
1268 Conn: c,
1269 addrs: make(chan net.Addr, 1),
1270 }
1271 l.conns <- brac
1272 return brac, nil
1273 }
1274
1275 type blockingRemoteAddrConn struct {
1276 net.Conn
1277 addrs chan net.Addr
1278 }
1279
1280 func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
1281 return <-c.addrs
1282 }
1283
1284
// TestServerAllowsBlockingRemoteAddr checks that a connection whose RemoteAddr
// call blocks does not stop the server from serving a second connection.
// HTTP/1 only.
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
	run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
}
func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
	accepted := make(chan net.Conn)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
	}), func(ts *httptest.Server) {
		ts.Listener = &blockingRemoteAddrListener{
			Listener: ts.Listener,
			conns:    accepted,
		}
	}).ts

	client := ts.Client()
	// One TCP connection per request, so each fetch maps to its own accepted conn.
	client.Transport.(*Transport).DisableKeepAlives = true

	// fetch GETs ts.URL and sends the body on response, or "" on error.
	fetch := func(num int, response chan<- string) {
		resp, err := client.Get(ts.URL)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		response <- string(body)
	}

	// Request 1: its connection is accepted, but RemoteAddr blocks.
	first := make(chan string, 1)
	go fetch(1, first)
	conn1 := <-accepted

	// Request 2 starts while connection 1 is still blocked.
	second := make(chan string, 1)
	go fetch(2, second)
	conn2 := <-accepted

	// Unblock connection 2 first; it must complete while connection 1 is still stuck.
	conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{IP: net.ParseIP("12.12.12.12"), Port: 12}
	if got, want := <-second, "RA:12.12.12.12:12"; got != want {
		t.Fatalf("response 2 addr = %q; want %q", got, want)
	}

	// Finally release connection 1.
	conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{IP: net.ParseIP("21.21.21.21"), Port: 21}
	if got, want := <-first, "RA:21.21.21.21:21"; got != want {
		t.Fatalf("response 1 addr = %q; want %q", got, want)
	}
}
1352
1353
1354
// TestHeadResponses checks that a HEAD response has the headers a GET would
// have and no body:
//   - Content-Type sniffed as text/html
//   - Content-Length of the 10 bytes the handler wrote
//   - no Transfer-Encoding
//
// It also checks that the handler's writes report no error.
func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
func testHeadResponses(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("<html>"))
		if err != nil {
			t.Errorf("ResponseWriter.Write: %v", err)
		}

		// Also write through io.Copy (presumably to exercise the
		// ResponseWriter's io.ReaderFrom path, if it has one).
		_, err = io.Copy(w, strings.NewReader("789a"))
		if err != nil {
			t.Errorf("Copy(ResponseWriter, ...): %v", err)
		}
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		// Fatal, not Error: every check below dereferences res, which is nil here.
		t.Fatal(err)
	}
	defer res.Body.Close()
	if len(res.TransferEncoding) > 0 {
		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
	}
	if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
		t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
	}
	if v := res.ContentLength; v != 10 {
		t.Errorf("Content-Length: %d; want 10", v)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Error(err)
	}
	if len(body) > 0 {
		t.Errorf("got unexpected body %q", string(body))
	}
}
1390
// TestTLSHandshakeTimeout checks that a TLS server with a ReadTimeout drops a
// client that never starts the TLS handshake and logs an error about it.
func TestTLSHandshakeTimeout(t *testing.T) {
	run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
}
func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
	errc := make(chanWriter, 10)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
		func(ts *httptest.Server) {
			ts.Config.ReadTimeout = 250 * time.Millisecond
			ts.Config.ErrorLog = log.New(errc, "", 0)
		},
	).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer conn.Close()

	// Never send a ClientHello; the read should end with an error and no data.
	var one [1]byte
	if n, err := conn.Read(one[:]); n != 0 || err == nil {
		t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
	}

	// Either substring in the logged error is accepted.
	logged := <-errc
	mentionsTimeout := strings.Contains(logged, "timeout")
	mentionsHandshake := strings.Contains(logged, "TLS handshake")
	if !mentionsTimeout && !mentionsHandshake {
		t.Errorf("expected a TLS handshake timeout error; got %q", logged)
	}
}
1419
// TestTLSServer checks that handlers on a TLS test server see a non-nil
// Request.TLS with HandshakeComplete set. The handler echoes these facts back
// as response headers.
func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
func testTLSServer(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.TLS == nil {
			return
		}
		w.Header().Set("X-TLS-Set", "true")
		if r.TLS.HandshakeComplete {
			w.Header().Set("X-TLS-HandshakeComplete", "true")
		}
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Hold open a raw TCP connection that never starts a TLS handshake while
	// the real client below makes its request (presumably to show that an
	// idle, unauthenticated connection does not block other clients).
	idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer idleConn.Close()

	if !strings.HasPrefix(ts.URL, "https://") {
		t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
		return
	}
	res, err := ts.Client().Get(ts.URL)
	switch {
	case err != nil:
		t.Error(err)
		return
	case res == nil:
		t.Errorf("got nil Response")
		return
	}
	defer res.Body.Close()
	if res.Header.Get("X-TLS-Set") != "true" {
		t.Errorf("expected X-TLS-Set response header")
		return
	}
	if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
		t.Errorf("expected X-TLS-HandshakeComplete header")
	}
}
1467
// TestServeTLS checks that Server.ServeTLS on a caller-supplied listener, with
// certificates given only through TLSConfig (empty cert and key file
// arguments), negotiates "h2" with a client offering h2 and http/1.1.
func TestServeTLS(t *testing.T) {
	CondSkipHTTP2(t)

	defer afterTest(t)
	defer SetTestHookServerServe(nil)

	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}

	ln := newLocalListener(t)
	defer ln.Close()

	// The hook fires once the server starts serving.
	serving := make(chan bool, 1)
	SetTestHookServerServe(func(s *Server, ln net.Listener) {
		serving <- true
	})
	srv := &Server{
		Addr:      ln.Addr().String(),
		TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}},
		Handler:   HandlerFunc(func(w ResponseWriter, r *Request) {}),
	}
	errc := make(chan error, 1)
	go func() { errc <- srv.ServeTLS(ln, "", "") }()
	select {
	case err := <-errc:
		t.Fatalf("ServeTLS: %v", err)
	case <-serving:
	}

	conn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	state := conn.ConnectionState()
	if got, want := state.NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := state.NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1519
1520
// TestTLSServerRejectHTTPRequests checks that a plaintext HTTP/1.1 request sent
// to a TLS server gets a plaintext "HTTP/1.0 400 Bad Request" reply and never
// reaches the handler.
func TestTLSServerRejectHTTPRequests(t *testing.T) {
	run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
}
func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		t.Error("unexpected HTTPS request")
	}), func(ts *httptest.Server) {
		// Send the server's error log to a throwaway buffer.
		var errBuf bytes.Buffer
		ts.Config.ErrorLog = log.New(&errBuf, "", 0)
	}).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
	reply, err := io.ReadAll(conn)
	if err != nil {
		t.Fatal(err)
	}
	const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
	if got := string(reply); !strings.HasPrefix(got, wantPrefix) {
		t.Errorf("response = %q; wanted prefix %q", got, wantPrefix)
	}
}
1546
1547
1548 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1549 testAutomaticHTTP2_Serve(t, nil, true)
1550 }
1551
1552 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1553 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1554 }
1555
1556 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1557 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1558 }
1559
1560 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1561 setParallel(t)
1562 defer afterTest(t)
1563 ln := newLocalListener(t)
1564 ln.Close()
1565 var s Server
1566 s.TLSConfig = tlsConf
1567 if err := s.Serve(ln); err == nil {
1568 t.Fatal("expected an error")
1569 }
1570 gotH2 := s.TLSNextProto["h2"] != nil
1571 if gotH2 != wantH2 {
1572 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1573 }
1574 }
1575
1576 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1577 setParallel(t)
1578 defer afterTest(t)
1579 ln := newLocalListener(t)
1580 ln.Close()
1581 var s Server
1582
1583
1584 s.TLSConfig = &tls.Config{
1585 NextProtos: []string{"h2"},
1586 }
1587 if err := s.Serve(ln); err == nil {
1588 t.Fatal("expected an error")
1589 }
1590 on := s.TLSNextProto["h2"] != nil
1591 if !on {
1592 t.Errorf("http2 wasn't automatically enabled")
1593 }
1594 }
1595
// TestAutomaticHTTP2_ListenAndServe: HTTP/2 is negotiated when the certificate
// is supplied via TLSConfig.Certificates.
func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	conf := &tls.Config{Certificates: []tls.Certificate{cert}}
	testAutomaticHTTP2_ListenAndServe(t, conf)
}

// TestAutomaticHTTP2_ListenAndServe_GetCertificate: HTTP/2 is negotiated when
// the certificate is supplied only via TLSConfig.GetCertificate.
func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	conf := &tls.Config{
		GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
			return &cert, nil
		},
	}
	testAutomaticHTTP2_ListenAndServe(t, conf)
}
1617
// testAutomaticHTTP2_ListenAndServe starts a Server configured with tlsConf via
// ListenAndServeTLS("", "") on a freshly picked local address. It then dials it
// with a TLS client offering "h2" and "http/1.1" and checks that "h2" was
// negotiated, i.e. that HTTP/2 was enabled without explicit configuration.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)

	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
	// Picking an address and then re-binding it is racy, so retry startup up
	// to maxTries times.
Try:
	for try := 0; try < maxTries; try++ {
		// Reserve a local address, then release it for ListenAndServeTLS to bind.
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		// The hook receives the listener the server serves on (presumably
		// invoked from Server.Serve; see SetTestHookServerServe).
		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		// Either startup failed (retry with a new address) or the server
		// reached the hook and we now hold its real listener.
		select {
		case err := <-errc:
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	// Closing the server's listener at the end makes ListenAndServeTLS return.
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1672
// serverExpectTest describes one request made by testServerExpect and the
// first response line expected for it.
type serverExpectTest struct {
	contentLength    int    // value sent in the Content-Length header (also the body size)
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the Expect header
	readBody         bool   // sent as ?readbody=; the handler reads the body only when true
	expectedResponse string // substring expected in the first response line
}

// expectTest builds a non-chunked serverExpectTest from its fields.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var st serverExpectTest
	st.contentLength = contentLength
	st.expectation = expectation
	st.readBody = readBody
	st.expectedResponse = expectedResponse
	return st
}
1689
// serverExpectTests is the table run by testServerExpect. The handler there
// reads the body and writes "Hi" (200 OK) when readBody is true, and replies
// 401 Unauthorized without reading the body otherwise.
var serverExpectTests = []serverExpectTest{
	// Expect: 100-continue with a body the handler reads. The server should
	// send "100 Continue" first; the Expect value is matched case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// No Expect header: the body is sent up front and the handler replies 200.
	expectTest(100, "", true, "200 OK"),

	// 100-continue, but the handler refuses without reading the body, so the
	// first line seen is the 401 rather than "100 Continue".
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// Same refusal without an Expect header.
	expectTest(100, "", false, "401 Unauthorized"),

	// An unsupported expectation is rejected with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// 100-continue with an empty (Content-Length: 0) body is accepted.
	expectTest(0, "100-continue", true, "200 OK"),
	// 100-continue with an empty body, handler does not read it.
	expectTest(0, "100-continue", false, "401 Unauthorized"),
	// 100-continue with a chunked body the handler reads.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1719
1720
1721
// TestServerExpect drives serverExpectTests over raw TCP against an HTTP/1
// server and checks the first response line for each Expect scenario.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Inspect the raw query instead of r.FormValue, presumably so that
		// nothing (such as form parsing) touches r.Body unless the test asks
		// the handler to read it.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest sends one request for test and checks the first response line.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body right away only when acting like a client that does
		// not wait for "100 Continue" (non-empty body, no 100-continue).
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// Defers run LIFO: wait for the writer goroutine before closing conn.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		// Writer goroutine: request headers, then (if writeBody) the body,
		// while the main goroutine reads the response concurrently.
		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// Plain body: write to conn directly; Close is a no-op.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler did not want the body, so the server
						// has likely already hung up on us; not a failure.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// Tolerated: we were still sending a body the handler never
				// read when the server closed the connection. Presumably the
				// TCP stack may then reset the connection (RST) before we read
				// the response, so losing the response line is not a failure.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
1817
1818
1819
// TestServerUnreadRequestBodyLittle checks three things when a handler ignores
// a 100 KB request body:
//   - the body is still pending when the handler starts;
//   - the server has consumed it once the header is flushed;
//   - no Connection header is added to the response.
func TestServerUnreadRequestBodyLittle(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	conn := new(testConn)
	body := strings.Repeat("x", 100<<10)
	fmt.Fprintf(&conn.readBuf,
		"POST / HTTP/1.1\r\n"+
			"Host: test\r\n"+
			"Content-Length: %d\r\n"+
			"\r\n", len(body))
	conn.readBuf.Write([]byte(body))

	done := make(chan bool)

	// pending reports the unread input under the conn's read lock.
	pending := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		defer close(done)
		if n := pending(); n < len(body)/2 {
			t.Errorf("on request, read buffer length is %d; expected about 100 KB", n)
		}
		rw.WriteHeader(200)
		rw.(Flusher).Flush()
		if got, want := pending(), 0; got != want {
			t.Errorf("after WriteHeader, read buffer length is %d; want %d", got, want)
		}
		if c := rw.Header().Get("Connection"); c != "" {
			t.Errorf(`Connection header = %q; want ""`, c)
		}
	}))
	<-done
}
1857
1858
1859
1860
// TestServerUnreadRequestBodyLarge checks what the server does when a handler
// ignores a 1 MB request body:
//   - it does not drain the body;
//   - it replies with "Connection: close" instead.
func TestServerUnreadRequestBodyLarge(t *testing.T) {
	setParallel(t)
	if testing.Short() && testenv.Builder() == "" {
		// Actually skip, as the message says (the original only logged).
		t.Skip("skipping in short mode")
	}
	conn := new(testConn)
	body := strings.Repeat("x", 1<<20)
	conn.readBuf.Write([]byte(fmt.Sprintf(
		"POST / HTTP/1.1\r\n"+
			"Host: test\r\n"+
			"Content-Length: %d\r\n"+
			"\r\n", len(body))))
	conn.readBuf.Write([]byte(body))
	conn.closec = make(chan bool, 1)

	// readBufLen reads the pending input under readMu, as the Little variant
	// does, since the server may be reading the conn concurrently.
	readBufLen := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	ls := &oneConnListener{conn}
	go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
		if n := readBufLen(); n < len(body)/2 {
			t.Errorf("on request, read buffer length is %d; expected about 1MB", n)
		}
		rw.WriteHeader(200)
		rw.(Flusher).Flush()
		if n := readBufLen(); n < len(body)/2 {
			t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", n)
		}
	}))
	<-conn.closec

	if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
		t.Errorf("Expected a Connection: close header; got response: %s", res)
	}
}
1893
// handlerBodyCloseTest is one case for testHandlerBodyClose.
type handlerBodyCloseTest struct {
	bodySize     int  // request body size in bytes
	bodyChunked  bool // send the body chunked instead of with Content-Length
	reqConnClose bool // add "Connection: close" to the request

	wantEOFSearch bool // closing the body is expected to consume buffered input
	wantNextReq   bool // a pipelined follow-up request must be served
}

// connectionHeader returns the request's Connection header line, or "" when
// the request does not ask to close.
func (t handlerBodyCloseTest) connectionHeader() string {
	if !t.reqConnClose {
		return ""
	}
	return "Connection: close\r\n"
}
1909
// handlerBodyCloseTests enumerates cases for testHandlerBodyClose, where the
// handler closes a request body it never read.
//   - wantEOFSearch is checked exactly: Close must (or must not) consume input.
//   - wantNextReq only requires the follow-up GET to be served when true.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small (20 KB) Content-Length body, keep-alive: the rest of the body is
	// consumed and the pipelined next request is served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small chunked body, keep-alive: same expectations as case 0.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small Content-Length body with "Connection: close": no follow-up
	// request is sent, and Close must not read the remaining body.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with "Connection: close": Close is still expected
	// to consume input, but there is no follow-up request.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large (1 MB) Content-Length body, keep-alive: presumably too big for the
	// server to discard, so Close must not read it.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Large chunked body, keep-alive: Close is expected to consume some input;
	// serving the next request is not required.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large chunked body with "Connection: close": Close still consumes some input.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large Content-Length body with "Connection: close": Close must not read it.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
1994
// TestHandlerBodyClose runs every handlerBodyCloseTests case (skipped in
// -short mode unless running on a builder).
func TestHandlerBodyClose(t *testing.T) {
	setParallel(t)
	if testing.Short() && testenv.Builder() == "" {
		t.Skip("skipping in -short mode")
	}
	for idx := range handlerBodyCloseTests {
		testHandlerBodyClose(t, idx, handlerBodyCloseTests[idx])
	}
}
2004
// testHandlerBodyClose runs case i:
//   - it scripts a POST (chunked or with Content-Length, optionally with
//     "Connection: close"), followed by a pipelined GET unless the POST asks
//     to close;
//   - the first handler call closes the body without reading it;
//   - it checks whether that Close consumed buffered input (the "EOF search")
//     and, if required, that the GET was served.
func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
	conn := new(testConn)
	payload := strings.Repeat("x", tt.bodySize)
	if !tt.bodyChunked {
		conn.readBuf.Write([]byte(fmt.Sprintf(
			"POST / HTTP/1.1\r\n"+
				"Host: test\r\n"+
				tt.connectionHeader()+
				"Content-Length: %d\r\n"+
				"\r\n", len(payload))))
		conn.readBuf.Write([]byte(payload))
	} else {
		conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
			"Host: test\r\n" +
			tt.connectionHeader() +
			"Transfer-Encoding: chunked\r\n" +
			"\r\n")
		chunked := internal.NewChunkedWriter(&conn.readBuf)
		io.WriteString(chunked, payload)
		chunked.Close()
		conn.readBuf.WriteString("\r\n")
	}
	if !tt.reqConnClose {
		conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
	}
	conn.closec = make(chan bool, 1)

	// pending reports the unread input under the conn's read lock.
	pending := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	var numReqs, size0, size1 int
	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		numReqs++
		if numReqs != 1 {
			return
		}
		size0 = pending()
		req.Body.Close()
		size1 = pending()
	}))
	<-conn.closec
	if numReqs < 1 || numReqs > 2 {
		t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
	}
	if didSearch := size0 != size1; didSearch != tt.wantEOFSearch {
		t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, tt.wantEOFSearch, size0, size1)
	}
	if tt.wantNextReq && numReqs != 2 {
		t.Errorf("%d. numReq = %d; want 2", i, numReqs)
	}
}
2061
2062
2063
2064 type testHandlerBodyConsumer struct {
2065 name string
2066 f func(io.ReadCloser)
2067 }
2068
2069 var testHandlerBodyConsumers = []testHandlerBodyConsumer{
2070 {"nil", func(io.ReadCloser) {}},
2071 {"close", func(r io.ReadCloser) { r.Close() }},
2072 {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }},
2073 }
2074
// TestRequestBodyReadErrorClosesConnection sends a chunked POST whose chunk
// size line is not valid hex ("hax"), followed by a pipelined GET /secret. It
// checks that, whatever the handler does with the body, only the first request
// reaches the handler.
func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	for _, consumer := range testHandlerBodyConsumers {
		conn := new(testConn)
		conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
			"Host: test\r\n" +
			"Transfer-Encoding: chunked\r\n" +
			"\r\n" +
			"hax\r\n" +
			"GET /secret HTTP/1.1\r\n" +
			"Host: test\r\n" +
			"\r\n")
		conn.closec = make(chan bool, 1)

		var numReqs int
		go Serve(&oneConnListener{conn}, HandlerFunc(func(_ ResponseWriter, req *Request) {
			numReqs++
			if strings.Contains(req.URL.Path, "secret") {
				t.Error("Request for /secret encountered, should not have happened.")
			}
			consumer.f(req.Body)
		}))
		<-conn.closec
		if numReqs != 1 {
			t.Errorf("Handler %v: got %d reqs; want 1", consumer.name, numReqs)
		}
	}
}
2105
// TestInvalidTrailerClosesConnection sends a chunked POST whose trailer line
// is malformed ("I'm not a valid trailer"), followed by a pipelined
// GET /secret. It checks that only the first request reaches the handler,
// for every body consumer.
func TestInvalidTrailerClosesConnection(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	for _, consumer := range testHandlerBodyConsumers {
		conn := new(testConn)
		conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
			"Host: test\r\n" +
			"Trailer: hack\r\n" +
			"Transfer-Encoding: chunked\r\n" +
			"\r\n" +
			"3\r\n" +
			"hax\r\n" +
			"0\r\n" +
			"I'm not a valid trailer\r\n" +
			"GET /secret HTTP/1.1\r\n" +
			"Host: test\r\n" +
			"\r\n")
		conn.closec = make(chan bool, 1)

		var numReqs int
		go Serve(&oneConnListener{conn}, HandlerFunc(func(_ ResponseWriter, req *Request) {
			numReqs++
			if strings.Contains(req.URL.Path, "secret") {
				t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", consumer.name)
			}
			consumer.f(req.Body)
		}))
		<-conn.closec
		if numReqs != 1 {
			t.Errorf("Handler %s: got %d reqs; want 1", consumer.name, numReqs)
		}
	}
}
2140
2141
2142
2143
2144 type slowTestConn struct {
2145
2146 script []any
2147 closec chan bool
2148
2149 mu sync.Mutex
2150 rd, wd time.Time
2151 noopConn
2152 }
2153
2154 func (c *slowTestConn) SetDeadline(t time.Time) error {
2155 c.SetReadDeadline(t)
2156 c.SetWriteDeadline(t)
2157 return nil
2158 }
2159
2160 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2161 c.mu.Lock()
2162 defer c.mu.Unlock()
2163 c.rd = t
2164 return nil
2165 }
2166
2167 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2168 c.mu.Lock()
2169 defer c.mu.Unlock()
2170 c.wd = t
2171 return nil
2172 }
2173
// Read serves the next bytes from c.script:
//   - A string cue is copied into b. If b is too small, the unread remainder
//     stays at the head of the script.
//   - A time.Duration cue sleeps, then Read continues with the next cue.
//   - If the read deadline has passed, or would pass during a pause, Read
//     returns syscall.ETIMEDOUT.
//   - An exhausted script yields io.EOF.
//
// c.mu is held for the whole call, including any sleeps.
func (c *slowTestConn) Read(b []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
restart:
	if !c.rd.IsZero() && time.Now().After(c.rd) {
		return 0, syscall.ETIMEDOUT
	}
	if len(c.script) == 0 {
		return 0, io.EOF
	}

	switch cue := c.script[0].(type) {
	case time.Duration:
		if !c.rd.IsZero() {
			// The deadline falls inside this pause: sleep only until the
			// deadline, keep the rest of the pause at the head of the
			// script, and report a timeout.
			if remaining := time.Until(c.rd); remaining < cue {
				c.script[0] = cue - remaining
				time.Sleep(remaining)
				return 0, syscall.ETIMEDOUT
			}
		}
		c.script = c.script[1:]
		time.Sleep(cue)
		goto restart

	case string:
		n = copy(b, cue)
		// If cue did not fit in b, leave its tail for the next Read.
		if len(cue) > n {
			c.script[0] = cue[n:]
		} else {
			c.script = c.script[1:]
		}

	default:
		panic("unknown cue in slowTestConn script")
	}

	return
}
2215
// Close tries to send on closec without blocking. If nobody is receiving and
// the buffer is full, the signal is dropped. It always returns nil.
func (c *slowTestConn) Close() error {
	select {
	case c.closec <- true:
	default:
	}
	return nil
}

// Write discards b and reports success, unless a write deadline is set and has
// passed, in which case it returns syscall.ETIMEDOUT.
func (c *slowTestConn) Write(b []byte) (int, error) {
	// NOTE(review): c.wd is read without c.mu. Read holds c.mu across
	// time.Sleep, so taking the lock here could stall writes.
	if deadline := c.wd; deadline.IsZero() || !time.Now().After(deadline) {
		return len(b), nil
	}
	return 0, syscall.ETIMEDOUT
}
2230
// TestRequestBodyTimeoutClosesConnection checks behavior when a request body
// stalls longer than the server's ReadTimeout:
//   - the connection is closed;
//   - the pipelined GET /secret that follows is never handled.
//
// This is checked for every body consumer.
func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	defer afterTest(t)
	for _, consumer := range testHandlerBodyConsumers {
		conn := &slowTestConn{
			script: []any{
				"POST /public HTTP/1.1\r\n" +
					"Host: test\r\n" +
					"Content-Length: 10000\r\n" +
					"\r\n",
				// Only a few of the promised 10000 body bytes arrive...
				"foo bar baz",
				// ...then the client stalls for longer than ReadTimeout below.
				600 * time.Millisecond,
				"GET /secret HTTP/1.1\r\n" +
					"Host: test\r\n" +
					"\r\n",
			},
			closec: make(chan bool, 1),
		}

		var numReqs int
		srv := Server{
			Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
				numReqs++
				if strings.Contains(req.URL.Path, "secret") {
					t.Error("Request for /secret encountered, should not have happened.")
				}
				consumer.f(req.Body)
			}),
			ReadTimeout: 400 * time.Millisecond,
		}
		go srv.Serve(&oneConnListener{conn})
		<-conn.closec

		if numReqs != 1 {
			t.Errorf("Handler %v: got %d reqs; want 1", consumer.name, numReqs)
		}
	}
}
2272
2273
2274 type cancelableTimeoutContext struct {
2275 context.Context
2276 }
2277
2278 func (c cancelableTimeoutContext) Err() error {
2279 if c.Context.Err() != nil {
2280 return context.DeadlineExceeded
2281 }
2282 return nil
2283 }
2284
// TestTimeoutHandler exercises a timeout handler whose timeout is triggered
// manually via cancelableTimeoutContext:
//   - Before the timeout, the inner handler's "hi" is delivered with 200.
//   - After cancel(), the client gets 503 with the HTML timeout body.
//   - The still-running inner handler's Write fails with ErrHandlerTimeout.
func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
func testTimeoutHandler(t *testing.T, mode testMode) {
	sendHi := make(chan bool, 1)
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		<-sendHi
		_, werr := w.Write([]byte("hi"))
		writeErrors <- werr
	})
	ctx, cancel := context.WithCancel(context.Background())
	h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
	cst := newClientServerTest(t, mode, h)

	// Succeed without timing out: release the handler before the request.
	sendHi <- true
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// Fatal, not Error: res is nil and dereferenced below.
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusOK; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if g, e := string(body), "hi"; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	if g := <-writeErrors; g != nil {
		t.Errorf("got unexpected Write error on first request: %v", g)
	}

	// Now make the next request time out.
	cancel()

	res, err = cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ = io.ReadAll(res.Body)
	res.Body.Close()
	if !strings.Contains(string(body), "<title>Timeout</title>") {
		t.Errorf("expected timeout body; got %q", string(body))
	}
	if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
		t.Errorf("response content-type = %q; want %q", g, w)
	}

	// Release the inner handler that is still blocked; its late Write must
	// report ErrHandlerTimeout.
	sendHi <- true
	if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
		t.Errorf("expected Write error of %v; got %v", e, g)
	}
}
2340
2341
// TestTimeoutHandlerRace sends many concurrent requests through a
// TimeoutHandler with a 20ms limit. Each inner handler writes "hi" once per
// millisecond, for N iterations where N is taken from the URL path (1-49), so
// some requests finish before the timeout and some after. Responses are
// drained but not checked; the test only needs to run cleanly (e.g. under the
// race detector).
func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
func testTimeoutHandlerRace(t *testing.T, mode testMode) {
	delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		iterations, _ := strconv.Atoi(r.URL.Path[1:])
		if iterations == 0 {
			iterations = 1
		}
		for k := 0; k < iterations; k++ {
			w.Write([]byte("hi"))
			time.Sleep(time.Millisecond)
		}
	})

	ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts

	client := ts.Client()

	total, concurrency := 50, 10
	if testing.Short() {
		total, concurrency = 10, 3
	}
	// gate bounds the number of in-flight requests.
	gate := make(chan bool, concurrency)
	var wg sync.WaitGroup
	for k := 0; k < total; k++ {
		gate <- true
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer func() { <-gate }()
			res, err := client.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
			if err != nil {
				return
			}
			io.Copy(io.Discard, res.Body)
			res.Body.Close()
		}()
	}
	wg.Wait()
}
2381
2382
2383
// TestTimeoutHandlerRaceHeader hands testTimeoutHandlerRaceHeader to the run
// helper.
func TestTimeoutHandlerRaceHeader(t *testing.T) {
	run(t, testTimeoutHandlerRaceHeader)
}
// testTimeoutHandlerRaceHeader fires many concurrent requests at a
// TimeoutHandler with a 1ns deadline that wraps a handler which only calls
// WriteHeader(204). The handler's header write and the timeout therefore race
// with each other. Client errors are logged rather than treated as failures.
func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
	noContent := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(204)
	})
	ts := newClientServerTest(t, mode, TimeoutHandler(noContent, time.Nanosecond, "")).ts

	// sem bounds the number of requests in flight at once.
	sem := make(chan bool, 50)
	total := 500
	if testing.Short() {
		total = 10
	}
	client := ts.Client()

	var wg sync.WaitGroup
	for i := 0; i < total; i++ {
		sem <- true
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer func() { <-sem }()
			res, err := client.Get(ts.URL)
			if err != nil {
				// Tolerated: the point is to provoke server-side races,
				// not to validate every response.
				t.Log(err)
				return
			}
			defer res.Body.Close()
			io.Copy(io.Discard, res.Body)
		}()
	}
	wg.Wait()
}
2419
2420
// TestTimeoutHandlerRaceHeaderTimeout hands
// testTimeoutHandlerRaceHeaderTimeout to the run helper.
func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
	run(t, testTimeoutHandlerRaceHeaderTimeout)
}
// testTimeoutHandlerRaceHeaderTimeout drives a TimeoutHandler whose deadline
// is a cancelable context. The inner handler sets a Content-Type header and
// then blocks on sendHi before writing "hi".
//
// The first request is unblocked immediately and must return 200 with body
// "hi". After the context is canceled, the second request must get a 503 whose
// body contains "<title>Timeout</title>". Once that request's inner handler is
// released, its Write must fail with ErrHandlerTimeout.
func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
	sendHi := make(chan bool, 1)
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/plain")
		<-sendHi
		_, werr := w.Write([]byte("hi"))
		writeErrors <- werr
	})
	ctx, cancel := context.WithCancel(context.Background())
	h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
	cst := newClientServerTest(t, mode, h)

	// First request: release the handler up front so it completes normally.
	sendHi <- true
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// res is nil on error; stop before dereferencing it.
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusOK; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if g, e := string(body), "hi"; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	if g := <-writeErrors; g != nil {
		t.Errorf("got unexpected Write error on first request: %v", g)
	}

	// Cancel the deadline context. The next request's inner handler blocks
	// on sendHi, so TimeoutHandler must answer with its timeout response.
	cancel()

	res, err = cst.c.Get(cst.ts.URL)
	if err != nil {
		// Release the inner handler in case it is blocked on sendHi, so it
		// cannot hold up server shutdown. The first request already drained
		// sendHi, so this send cannot block.
		sendHi <- true
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ = io.ReadAll(res.Body)
	res.Body.Close()
	if !strings.Contains(string(body), "<title>Timeout</title>") {
		t.Errorf("expected timeout body; got %q", string(body))
	}

	// Release the handler only now. The timeout page has already been sent,
	// so its Write must be rejected.
	sendHi <- true
	if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
		t.Errorf("expected Write error of %v; got %v", e, g)
	}
}
2474
2475
// TestTimeoutHandlerStartTimerWhenServing hands
// testTimeoutHandlerStartTimerWhenServing to the run helper.
func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { run(t, testTimeoutHandlerStartTimerWhenServing) }
// testTimeoutHandlerStartTimerWhenServing sets up a TimeoutHandler
// (300ms timeout) around a handler that replies 204 immediately. It idles for
// twice the timeout before sending the first request, which must still get
// 204 No Content. The test is skipped in -short mode because of the sleep.
func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping sleeping test in -short mode")
	}
	const timeout = 300 * time.Millisecond
	noContent := HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.WriteHeader(StatusNoContent)
	})
	ts := newClientServerTest(t, mode, TimeoutHandler(noContent, timeout, "")).ts
	defer ts.Close()
	client := ts.Client()

	// Stay idle well past the timeout before issuing the request.
	time.Sleep(2 * timeout)

	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if got := res.StatusCode; got != StatusNoContent {
		t.Errorf("got res.StatusCode %d, want %v", got, StatusNoContent)
	}
}
2505
// TestTimeoutHandlerContextCanceled hands testTimeoutHandlerContextCanceled
// to the run helper.
func TestTimeoutHandlerContextCanceled(t *testing.T) {
	run(t, testTimeoutHandlerContextCanceled)
}
// testTimeoutHandlerContextCanceled builds a TimeoutHandler from a context
// that is canceled before the server starts. The request must get a 503 with
// an empty body, and the inner handler's Writes must fail with
// context.Canceled.
func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/plain")
		var err error
		// Write up to 100 times, 1ms apart, and report the first Write
		// error (or nil if every Write succeeded).
		for i := 0; i < 100; i++ {
			_, err = w.Write([]byte("a"))
			if err != nil {
				break
			}
			time.Sleep(1 * time.Millisecond)
		}
		writeErrors <- err
	})
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	h := NewTestTimeoutHandler(sayHi, ctx)
	cst := newClientServerTest(t, mode, h)
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// res is nil on error; stop before dereferencing it.
		t.Fatal(err)
	}
	defer res.Body.Close()
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	if g, e := string(body), ""; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	// %v, not %g: %g is a floating-point verb and garbles error values.
	if g, e := <-writeErrors, context.Canceled; !errors.Is(g, e) {
		t.Errorf("got unexpected Write in handler: %v, want %v", g, e)
	}
}
2545
2546
// TestTimeoutHandlerEmptyResponse hands testTimeoutHandlerEmptyResponse to
// the run helper.
func TestTimeoutHandlerEmptyResponse(t *testing.T) {
	run(t, testTimeoutHandlerEmptyResponse)
}
2548 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2549 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2550
2551 }
2552 timeout := 300 * time.Millisecond
2553 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2554
2555 c := ts.Client()
2556
2557 res, err := c.Get(ts.URL)
2558 if err != nil {
2559 t.Fatal(err)
2560 }
2561 defer res.Body.Close()
2562 if res.StatusCode != StatusOK {
2563 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2564 }
2565 }
2566
2567
2568 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2569 wrapper := func(h Handler) Handler {
2570 return TimeoutHandler(h, time.Second, "")
2571 }
2572 run(t, func(t *testing.T, mode testMode) {
2573 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2574 }, testNotParallel)
2575 }
2576
// TestRedirectBadPath calls Redirect with an empty target on a request whose
// URL path has no leading slash. The call must complete and record the
// requested 304 status.
func TestRedirectBadPath(t *testing.T) {
	target := &url.URL{Scheme: "http", Path: "not-empty-but-no-leading-slash"}
	req := &Request{Method: "GET", URL: target}
	rec := httptest.NewRecorder()
	Redirect(rec, req, "", 304)
	if got := rec.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2593
2594
2595 func TestRedirect(t *testing.T) {
2596 req, _ := NewRequest("GET", "http://example.com/qux/", nil)
2597
2598 var tests = []struct {
2599 in string
2600 want string
2601 }{
2602
2603 {"http://foobar.com/baz", "http://foobar.com/baz"},
2604
2605 {"https://foobar.com/baz", "https://foobar.com/baz"},
2606
2607 {"test://foobar.com/baz", "test://foobar.com/baz"},
2608
2609 {"//foobar.com/baz", "//foobar.com/baz"},
2610
2611 {"/foobar.com/baz", "/foobar.com/baz"},
2612
2613 {"foobar.com/baz", "/qux/foobar.com/baz"},
2614
2615 {"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
2616
2617 {"///foobar.com/baz", "/foobar.com/baz"},
2618
2619
2620 {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
2621 {"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
2622 "http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
2623
2624 {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2625 {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2626 }
2627
2628 for _, tt := range tests {
2629 rec := httptest.NewRecorder()
2630 Redirect(rec, req, tt.in, 302)
2631 if got, want := rec.Code, 302; got != want {
2632 t.Errorf("Redirect(%q) generated status code %v; want %v", tt.in, got, want)
2633 }
2634 if got := rec.Header().Get("Location"); got != tt.want {
2635 t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want)
2636 }
2637 }
2638 }
2639
2640
2641
// TestRedirectContentTypeAndBody checks the Content-Type header and body that
// Redirect writes, across request methods and pre-set Content-Type values:
//   - GET gets an HTML link body with an HTML Content-Type.
//   - HEAD gets the HTML Content-Type but no body.
//   - Other methods get neither.
//   - When the header map already holds a Content-Type entry (even an empty
//     or nil one), that entry is kept and no HTML body is written.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	cases := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}
	for _, tc := range cases {
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		Redirect(rec, httptest.NewRequest(tc.method, "http://example.com/qux/", nil), "/foo", 302)

		if got := rec.Code; got != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, got, 302)
		}
		if got := rec.Header().Get("Content-Type"); got != tc.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, got, tc.wantCT)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got := string(body); got != tc.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, tc.wantBody)
		}
	}
}
2685
2686
2687
2688
2689
2690
2691
// TestZeroLengthPostAndResponse hands testZeroLengthPostAndResponse to the
// run helper.
func TestZeroLengthPostAndResponse(t *testing.T) {
	run(t, testZeroLengthPostAndResponse)
}
2693
// testZeroLengthPostAndResponse sends the same zero-length POST five times.
// The handler must see an empty request body, and every response, served with
// an explicit "Content-Length: 0", must have an empty body too.
func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		all, err := io.ReadAll(r.Body)
		if err != nil {
			// Handlers run on server goroutines, where t.Fatalf must not
			// be called; report and bail out instead.
			t.Errorf("handler ReadAll: %v", err)
			return
		}
		if len(all) != 0 {
			t.Errorf("handler got %d bytes; expected 0", len(all))
		}
		rw.Header().Set("Content-Length", "0")
	}))

	req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
	if err != nil {
		t.Fatal(err)
	}
	req.ContentLength = 0

	var resp [5]*Response
	for i := range resp {
		resp[i], err = cst.c.Do(req)
		if err != nil {
			t.Fatalf("client post #%d: %v", i, err)
		}
		defer resp[i].Body.Close()
	}

	for i := range resp {
		all, err := io.ReadAll(resp[i].Body)
		if err != nil {
			t.Fatalf("req #%d: client ReadAll: %v", i, err)
		}
		if len(all) != 0 {
			t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
		}
	}
}
2730
// TestHandlerPanicNil runs testHandlerPanic with a nil panic value, without
// hijacking and without a wrapper. The runs are marked testNotParallel.
func TestHandlerPanicNil(t *testing.T) {
	panicNil := func(t *testing.T, mode testMode) { testHandlerPanic(t, false, mode, nil, nil) }
	run(t, panicNil, testNotParallel)
}
2736
2737 func TestHandlerPanic(t *testing.T) {
2738 run(t, func(t *testing.T, mode testMode) {
2739 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2740 }, testNotParallel)
2741 }
2742
// TestHandlerPanicWithHijack runs testHandlerPanic with a handler that
// hijacks its connection before panicking. It runs in HTTP/1 mode only.
func TestHandlerPanicWithHijack(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, func(t *testing.T, mode testMode) {
		testHandlerPanic(t, true, mode, nil, "intentional death for testing")
	}, http1Only)
}
2749
// testHandlerPanic installs a handler that panics with panicValue and issues a
// single GET to it. If withHijack is set, the handler hijacks its connection
// before panicking. If wrapper is non-nil, it wraps the panicking handler (for
// example in a TimeoutHandler).
//
// The server's ErrorLog is redirected into an io.Pipe. For a non-nil
// panicValue the test blocks until a read from that pipe returns, i.e. until
// the server has written something to its error log.
func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
	// Log output goes to a pipe rather than a buffer so the test can block
	// until something has actually been logged.
	pr, pw := io.Pipe()
	defer pw.Close()

	var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
		if withHijack {
			rwc, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Logf("unexpected error: %v", err)
			}
			// NOTE(review): if Hijack fails, rwc is nil and evaluating
			// rwc.Close in this defer statement likely panics before
			// panic(panicValue) is reached.
			defer rwc.Close()
		}
		panic(panicValue)
	})
	if wrapper != nil {
		handler = wrapper(handler)
	}
	cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(pw, "", 0)
	})

	// Read the first chunk (up to 4 KiB) of log output in the background.
	// done is signaled once that read returns, whether with data, EOF, or an
	// error; only non-EOF errors fail the test.
	done := make(chan bool, 1)
	go func() {
		buf := make([]byte, 4<<10)
		_, err := pr.Read(buf)
		pr.Close()
		if err != nil && err != io.EOF {
			t.Error(err)
		}
		done <- true
	}()

	// The request is expected to fail because the handler panics. A
	// success is only logged, not treated as a test failure.
	_, err := cst.c.Get(cst.ts.URL)
	if err == nil {
		t.Logf("expected an error")
	}

	// For a nil panic value the log pipe is not awaited.
	// NOTE(review): presumably nothing is guaranteed to be logged in that
	// case; confirm against the server's panic recovery.
	if panicValue == nil {
		return
	}

	<-done
}
2802
2803 type terrorWriter struct{ t *testing.T }
2804
2805 func (w terrorWriter) Write(p []byte) (int, error) {
2806 w.t.Errorf("%s", p)
2807 return len(p), nil
2808 }
2809
2810
2811
// TestServerWriteHijackZeroBytes runs testServerWriteHijackZeroBytes in HTTP/1
// mode only.
func TestServerWriteHijackZeroBytes(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testServerWriteHijackZeroBytes, http1Only)
}
// testServerWriteHijackZeroBytes has a handler flush, hijack its connection,
// and then make a zero-length Write on the ResponseWriter. That Write must
// return ErrHijacked. The server's ErrorLog is a terrorWriter, so any logged
// output also fails the test.
func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
	handlerDone := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(handlerDone)
		w.(Flusher).Flush()
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		defer conn.Close()
		if _, werr := w.Write(nil); werr != ErrHijacked {
			t.Errorf("Write error = %v; want ErrHijacked", werr)
		}
	})
	failOnLog := func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
	}
	ts := newClientServerTest(t, mode, handler, failOnLog).ts

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	<-handlerDone
}
2842
// TestServerNoDate checks that the Date response header can be suppressed.
func TestServerNoDate(t *testing.T) {
	noDate := func(t *testing.T, mode testMode) { testServerNoHeader(t, mode, "Date") }
	run(t, noDate)
}
2848
// TestServerContentType checks that the Content-Type response header can be
// suppressed.
func TestServerContentType(t *testing.T) {
	noContentType := func(t *testing.T, mode testMode) { testServerNoHeader(t, mode, "Content-Type") }
	run(t, noContentType)
}
2854
// testServerNoHeader has the handler set the named response header's map
// entry to nil before writing an HTML body. The client must then see no
// header of that name at all.
func testServerNoHeader(t *testing.T, mode testMode, header string) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// A nil value (rather than deleting the key) asks the server not to
		// add this header automatically.
		w.Header()[header] = nil
		io.WriteString(w, "<html>foo</html>")
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	got, present := res.Header[header]
	if present {
		t.Fatalf("Expected no %s header; got %q", header, got)
	}
}
2869
// TestStripPrefix hands testStripPrefix to the run helper.
func TestStripPrefix(t *testing.T) {
	run(t, testStripPrefix)
}
2871 func testStripPrefix(t *testing.T, mode testMode) {
2872 h := HandlerFunc(func(w ResponseWriter, r *Request) {
2873 w.Header().Set("X-Path", r.URL.Path)
2874 w.Header().Set("X-RawPath", r.URL.RawPath)
2875 })
2876 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
2877
2878 c := ts.Client()
2879
2880 cases := []struct {
2881 reqPath string
2882 path string
2883 rawPath string
2884 }{
2885 {"/foo/bar/qux", "/qux", ""},
2886 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
2887 {"/foo%2Fbar/qux", "", ""},
2888 {"/bar", "", ""},
2889 }
2890 for _, tc := range cases {
2891 t.Run(tc.reqPath, func(t *testing.T) {
2892 res, err := c.Get(ts.URL + tc.reqPath)
2893 if err != nil {
2894 t.Fatal(err)
2895 }
2896 res.Body.Close()
2897 if tc.path == "" {
2898 if res.StatusCode != StatusNotFound {
2899 t.Errorf("got %q, want 404 Not Found", res.Status)
2900 }
2901 return
2902 }
2903 if res.StatusCode != StatusOK {
2904 t.Fatalf("got %q, want 200 OK", res.Status)
2905 }
2906 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
2907 t.Errorf("got Path %q, want %q", g, w)
2908 }
2909 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
2910 t.Errorf("got RawPath %q, want %q", g, w)
2911 }
2912 })
2913 }
2914 }
2915
2916
// TestStripPrefixNotModifyRequest checks that StripPrefix leaves the caller's
// *Request untouched: the original URL path must still be "/foo/bar" after
// serving.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	req := httptest.NewRequest("GET", "/foo/bar", nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if got := req.URL.Path; got != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
2925
// TestRequestLimit hands testRequestLimit to the run helper.
func TestRequestLimit(t *testing.T) {
	run(t, testRequestLimit)
}
2927 func testRequestLimit(t *testing.T, mode testMode) {
2928 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2929 t.Fatalf("didn't expect to get request in Handler")
2930 }), optQuietLog)
2931 req, _ := NewRequest("GET", cst.ts.URL, nil)
2932 var bytesPerHeader = len("header12345: val12345\r\n")
2933 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
2934 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
2935 }
2936 res, err := cst.c.Do(req)
2937 if res != nil {
2938 defer res.Body.Close()
2939 }
2940 if mode == http2Mode {
2941
2942
2943
2944
2945 if err == nil && res.StatusCode != 431 {
2946 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
2947 }
2948 } else {
2949
2950
2951
2952
2953 if err != nil {
2954 t.Fatalf("Do: %v", err)
2955 }
2956 if res.StatusCode != 431 {
2957 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
2958 }
2959 }
2960 }
2961
// neverEnding is an io.Reader that yields an endless stream of one byte value.
type neverEnding byte

// Read fills every byte of p with b and never returns an error.
func (b neverEnding) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	// Seed the first byte, then double the filled prefix with copy.
	p[0] = byte(b)
	for filled := 1; filled < len(p); filled *= 2 {
		copy(p[filled:], p[:filled])
	}
	return len(p), nil
}
2970
2971 type countReader struct {
2972 r io.Reader
2973 n *int64
2974 }
2975
2976 func (cr countReader) Read(p []byte) (n int, err error) {
2977 n, err = cr.r.Read(p)
2978 atomic.AddInt64(cr.n, int64(n))
2979 return
2980 }
2981
// TestRequestBodyLimit hands testRequestBodyLimit to the run helper.
func TestRequestBodyLimit(t *testing.T) {
	run(t, testRequestBodyLimit)
}
// testRequestBodyLimit has the handler wrap the request body in a
// MaxBytesReader limited to 1 MiB and drain it. The handler expects exactly
// limit bytes followed by a *MaxBytesError whose Limit equals limit.
//
// The client offers a body 200x the limit. The test then checks that it did
// not manage to upload more than 100x the limit (presumably because the server
// stops reading once the limit is hit).
func testRequestBodyLimit(t *testing.T, mode testMode) {
	const limit = 1 << 20
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		r.Body = MaxBytesReader(w, r.Body, limit)
		n, err := io.Copy(io.Discard, r.Body)
		if err == nil {
			t.Errorf("expected error from io.Copy")
		}
		if n != limit {
			t.Errorf("io.Copy = %d, want %d", n, limit)
		}
		var mbErr *MaxBytesError
		if !errors.As(err, &mbErr) {
			// mbErr is nil here; don't dereference it below.
			t.Errorf("expected MaxBytesError, got %T", err)
			return
		}
		if mbErr.Limit != limit {
			t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
		}
	}))

	nWritten := new(int64)
	req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
	if err != nil {
		t.Fatal(err)
	}

	// The request may fail once the server stops reading the body; only the
	// number of bytes the client managed to send matters here.
	resp, err := cst.c.Do(req)
	if err == nil {
		resp.Body.Close()
	}

	// Report the byte count, not the *int64 that holds it.
	if written := atomic.LoadInt64(nWritten); written > limit*100 {
		t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
			limit, written)
	}
}
3025
3026
3027
// TestClientWriteShutdown hands testClientWriteShutdown to the run helper.
func TestClientWriteShutdown(t *testing.T) {
	run(t, testClientWriteShutdown)
}
// testClientWriteShutdown opens a raw TCP connection to the server and
// half-closes its write side without sending a request. Reading until EOF must
// then yield no bytes: the server closes the connection without writing
// anything back. Skipped on plan9 (golang.org/issue/17906).
func testClientWriteShutdown(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/17906")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	// Release the socket on every exit path, including the t.Fatalf below.
	defer conn.Close()
	err = conn.(*net.TCPConn).CloseWrite()
	if err != nil {
		t.Fatalf("CloseWrite: %v", err)
	}

	bs, err := io.ReadAll(conn)
	if err != nil {
		t.Errorf("ReadAll: %v", err)
	}
	got := string(bs)
	if got != "" {
		t.Errorf("read %q from server; want nothing", got)
	}
}
3052
3053
3054
3055 func TestServerBufferedChunking(t *testing.T) {
3056 conn := new(testConn)
3057 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3058 conn.closec = make(chan bool, 1)
3059 ls := &oneConnListener{conn}
3060 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3061 rw.(Flusher).Flush()
3062 rw.Write([]byte{'x'})
3063 rw.Write([]byte{'y'})
3064 rw.Write([]byte{'z'})
3065 }))
3066 <-conn.closec
3067 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3068 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3069 conn.writeBuf.Bytes())
3070 }
3071 }
3072
3073
3074
3075
3076
// TestServerGracefulClose runs testServerGracefulClose in HTTP/1 mode only.
func TestServerGracefulClose(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testServerGracefulClose, http1Only)
}
// testServerGracefulClose sends a POST with a 5 MiB body to a handler that
// rejects it with 401 without reading the body. The upload runs in a separate
// goroutine. Meanwhile the client reads response lines until the server closes
// the connection, and the first line must be a "401 Unauthorized" status line.
func testServerGracefulClose(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		Error(w, "bye", StatusUnauthorized)
	})).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	const bodySize = 5 << 20
	head := fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize)
	// Build the body in bulk instead of appending 5 MiB one byte at a time.
	req := append([]byte(head), bytes.Repeat([]byte{'x'}, bodySize)...)

	// Buffered so the writer goroutine can always deliver its result and
	// exit, even when t.Fatalf below stops the test before receiving it.
	writeErr := make(chan error, 1)
	go func() {
		_, err := conn.Write(req)
		writeErr <- err
	}()

	br := bufio.NewReader(conn)
	lineNum := 0
	for {
		line, err := br.ReadString('\n')
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatalf("ReadLine: %v", err)
		}
		lineNum++
		if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
			t.Errorf("Response line = %q; want a 401", line)
		}
	}

	// Wait for the writer. Its error is ignored, presumably because the
	// server may close the connection before the whole body is written.
	<-writeErr
}
3120
// TestCaseSensitiveMethod hands testCaseSensitiveMethod to the run helper.
func TestCaseSensitiveMethod(t *testing.T) {
	run(t, testCaseSensitiveMethod)
}
3122 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3123 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3124 if r.Method != "get" {
3125 t.Errorf(`Got method %q; want "get"`, r.Method)
3126 }
3127 }))
3128 defer cst.close()
3129 req, _ := NewRequest("get", cst.ts.URL, nil)
3130 res, err := cst.c.Do(req)
3131 if err != nil {
3132 t.Error(err)
3133 return
3134 }
3135
3136 res.Body.Close()
3137 }
3138
3139
3140
3141
3142
// TestContentLengthZero runs testContentLengthZero in HTTP/1 mode only.
func TestContentLengthZero(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testContentLengthZero, http1Only)
}
// testContentLengthZero sends keep-alive GETs over raw connections, once as
// HTTP/1.0 and once as HTTP/1.1, to a handler that writes nothing. Each
// response must have a Content-Length of 0 and no Transfer-Encoding.
func testContentLengthZero(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts

	for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
		// Each iteration runs in its own function so the deferred Closes
		// fire at the end of the iteration, and also when t.Fatalf exits
		// early.
		func() {
			conn, err := net.Dial("tcp", ts.Listener.Addr().String())
			if err != nil {
				t.Fatalf("error dialing: %v", err)
			}
			defer conn.Close()

			_, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
			if err != nil {
				t.Fatalf("error writing: %v", err)
			}
			req, _ := NewRequest("GET", "/", nil)
			res, err := ReadResponse(bufio.NewReader(conn), req)
			if err != nil {
				t.Fatalf("error reading response: %v", err)
			}
			defer res.Body.Close()
			if te := res.TransferEncoding; len(te) > 0 {
				t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
			}
			if cl := res.ContentLength; cl != 0 {
				t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
			}
		}()
	}
}
3172
// TestCloseNotifier runs testCloseNotifier in HTTP/1 mode only.
func TestCloseNotifier(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testCloseNotifier, http1Only)
}
// testCloseNotifier checks that a handler blocked on its CloseNotify channel
// is woken when the client closes the connection. A client goroutine writes a
// keep-alive GET and then waits on diec. Once the handler reports the request,
// the test signals diec, the goroutine closes the conn, and the test waits for
// the handler to report that the close notification fired.
func testCloseNotifier(t *testing.T, mode testMode) {
	gotReq := make(chan bool, 1)
	sawClose := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		gotReq <- true
		cc := rw.(CloseNotifier).CloseNotify()
		<-cc
		sawClose <- true
	})).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("error dialing: %v", err)
	}
	diec := make(chan bool)
	go func() {
		// Use a goroutine-local err rather than writing to the test
		// function's err from another goroutine.
		if _, err := io.WriteString(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			conn.Close()
			t.Error(err)
			return
		}
		<-diec
		conn.Close()
	}()
For:
	for {
		select {
		case <-gotReq:
			diec <- true
		case <-sawClose:
			break For
		}
	}
	ts.Close()
}
3210
3211
3212
3213
3214
// TestCloseNotifierPipelined runs testCloseNotifierPipelined in HTTP/1 mode
// only.
func TestCloseNotifierPipelined(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testCloseNotifierPipelined, http1Only)
}
// testCloseNotifierPipelined writes two pipelined keep-alive GETs in a single
// write and keeps the connection open. Each handler waits 100ms on its
// CloseNotify channel, which must not fire during that window. The test
// returns once both handlers have finished, and fails if more than two
// requests arrive.
func testCloseNotifierPipelined(t *testing.T, mode testMode) {
	gotReq := make(chan bool, 2)
	sawClose := make(chan bool, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		gotReq <- true
		cc := rw.(CloseNotifier).CloseNotify()
		select {
		case <-cc:
			t.Error("unexpected CloseNotify")
		case <-time.After(100 * time.Millisecond):
		}
		sawClose <- true
	})).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("error dialing: %v", err)
	}
	// Closing diec when the test returns tells the writer goroutine to close
	// the connection.
	diec := make(chan bool, 1)
	defer close(diec)
	go func() {
		const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
		// Use a goroutine-local err rather than writing to the test
		// function's err from another goroutine.
		if _, err := io.WriteString(conn, req+req); err != nil {
			conn.Close()
			t.Error(err)
			return
		}
		<-diec
		conn.Close()
	}()
	reqs := 0
	closes := 0
	for {
		select {
		case <-gotReq:
			reqs++
			if reqs > 2 {
				t.Fatal("too many requests")
			}
		case <-sawClose:
			closes++
			if closes > 1 {
				return
			}
		}
	}
}
3264
// TestCloseNotifierChanLeak serves 20 single-request HTTP/1.0 connections
// whose handler calls CloseNotify but never receives from the returned
// channel. Each iteration waits for its connection to signal closec. The
// deferred afterTest then runs its end-of-test checks (presumably for leaked
// goroutines; confirm against afterTest).
func TestCloseNotifierChanLeak(t *testing.T) {
	defer afterTest(t)
	req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
	// The handler deliberately ignores the channel returned by CloseNotify.
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		_ = rw.(CloseNotifier).CloseNotify()
	})
	for iter := 0; iter < 20; iter++ {
		var out bytes.Buffer
		conn := &rwTestConn{
			Reader: bytes.NewReader(req),
			Writer: &out,
			closec: make(chan bool, 1),
		}
		go Serve(&oneConnListener{conn: conn}, handler)
		<-conn.closec
	}
}
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
// TestHijackAfterCloseNotifier runs testHijackAfterCloseNotifier in HTTP/1
// mode only.
func TestHijackAfterCloseNotifier(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testHijackAfterCloseNotifier, http1Only)
}
// testHijackAfterCloseNotifier sends two requests through the test client.
//   - The first handler calls CloseNotify and reports r.RemoteAddr in
//     X-Addr.
//   - The second hijacks the connection and writes a raw response with its
//     own X-Addr.
//
// Equal, non-empty addresses show that both requests arrived over the same
// connection, which could still be hijacked after CloseNotify was used on it.
func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
	script := make(chan string, 2)
	script <- "closenotify"
	script <- "hijack"
	close(script)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		plan := <-script
		switch plan {
		default:
			panic("bogus plan; too many requests")
		case "closenotify":
			w.(CloseNotifier).CloseNotify() // result intentionally discarded
			w.Header().Set("X-Addr", r.RemoteAddr)
		case "hijack":
			c, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Errorf("Hijack in Handler: %v", err)
				return
			}
			if _, ok := c.(*net.TCPConn); !ok {
				// The hijacked conn should be the raw TCP connection.
				t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
			}
			fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
			c.Close()
			return
		}
	})).ts
	client := ts.Client()
	res1, err := client.Get(ts.URL)
	if err != nil {
		// log.Fatal would exit the entire test binary; fail only this test.
		t.Fatal(err)
	}
	res1.Body.Close()
	res2, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res2.Body.Close()
	addr1 := res1.Header.Get("X-Addr")
	addr2 := res2.Header.Get("X-Addr")
	if addr1 == "" || addr1 != addr2 {
		t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
	}
}
3342
// TestHijackBeforeRequestBodyRead runs testHijackBeforeRequestBodyRead in
// HTTP/1 mode only.
func TestHijackBeforeRequestBodyRead(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testHijackBeforeRequestBodyRead, http1Only)
}
3346 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3347 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3348 bodyOkay := make(chan bool, 1)
3349 gotCloseNotify := make(chan bool, 1)
3350 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3351 defer close(bodyOkay)
3352
3353 reqBody := r.Body
3354 r.Body = nil
3355
3356 gone := w.(CloseNotifier).CloseNotify()
3357 slurp, err := io.ReadAll(reqBody)
3358 if err != nil {
3359 t.Errorf("Body read: %v", err)
3360 return
3361 }
3362 if len(slurp) != len(requestBody) {
3363 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3364 return
3365 }
3366 if !bytes.Equal(slurp, requestBody) {
3367 t.Error("Backend read wrong request body.")
3368 return
3369 }
3370 bodyOkay <- true
3371 <-gone
3372 gotCloseNotify <- true
3373 })).ts
3374
3375 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3376 if err != nil {
3377 t.Fatal(err)
3378 }
3379 defer conn.Close()
3380
3381 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3382 len(requestBody), requestBody)
3383 if !<-bodyOkay {
3384
3385 return
3386 }
3387 conn.Close()
3388 <-gotCloseNotify
3389 }
3390
// TestOptions runs testOptions in HTTP/1 mode only.
func TestOptions(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testOptions, http1Only)
}
// testOptions checks the server's handling of "*" request targets on one raw
// connection:
//   - "OPTIONS *" is answered with 200 without reaching the ServeMux handler.
//   - "GET *" is rejected with 400.
//   - A normal request afterwards reaches the handler. The handler records
//     every RequestURI it sees, and the first one must be "/second".
func testOptions(t *testing.T, mode testMode) {
	uric := make(chan string, 2) // buffered so the handler never blocks
	mux := NewServeMux()
	mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
		uric <- r.RequestURI
	})
	ts := newClientServerTest(t, mode, mux).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	// "OPTIONS *" must get a 200 from the server itself.
	_, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
	if err != nil {
		t.Fatal(err)
	}
	br := bufio.NewReader(conn)
	res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
	}

	// "GET *" must be rejected with 400.
	_, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
	if err != nil {
		t.Fatal(err)
	}
	res, err = ReadResponse(br, &Request{Method: "GET"})
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 400 {
		t.Errorf("Got non-400 response to GET *: %#v", res)
	}

	// Use the test server's own client, as the other tests here do, rather
	// than the package-level Get and its DefaultClient.
	res, err = ts.Client().Get(ts.URL + "/second")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if got := <-uric; got != "/second" {
		t.Errorf("Handler saw request for %q; want /second", got)
	}
}
3442
// TestOptionsHandler runs testOptionsHandler in HTTP/1 mode only.
func TestOptionsHandler(t *testing.T) {
	http1Only := []testMode{http1Mode}
	run(t, testOptionsHandler, http1Only)
}
// testOptionsHandler sets DisableGeneralOptionsHandler on the server and sends
// an "OPTIONS *" request. That request must be delivered to the configured
// handler with Method "OPTIONS" and RequestURI "*".
func testOptionsHandler(t *testing.T, mode testMode) {
	requests := make(chan *Request, 1)
	disableOptions := func(ts *httptest.Server) {
		ts.Config.DisableGeneralOptionsHandler = true
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		requests <- r
	}), disableOptions).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	if _, err := conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}

	got := <-requests
	if got.Method != "OPTIONS" || got.RequestURI != "*" {
		t.Errorf("Expected OPTIONS * request, got %v", got)
	}
}
3468
3469
3470
3471
3472
3473
3474
3475
3476
3477
3478 func TestHeaderToWire(t *testing.T) {
3479 tests := []struct {
3480 name string
3481 handler func(ResponseWriter, *Request)
3482 check func(got, logs string) error
3483 }{
3484 {
3485 name: "write without Header",
3486 handler: func(rw ResponseWriter, r *Request) {
3487 rw.Write([]byte("hello world"))
3488 },
3489 check: func(got, logs string) error {
3490 if !strings.Contains(got, "Content-Length:") {
3491 return errors.New("no content-length")
3492 }
3493 if !strings.Contains(got, "Content-Type: text/plain") {
3494 return errors.New("no content-type")
3495 }
3496 return nil
3497 },
3498 },
3499 {
3500 name: "Header mutation before write",
3501 handler: func(rw ResponseWriter, r *Request) {
3502 h := rw.Header()
3503 h.Set("Content-Type", "some/type")
3504 rw.Write([]byte("hello world"))
3505 h.Set("Too-Late", "bogus")
3506 },
3507 check: func(got, logs string) error {
3508 if !strings.Contains(got, "Content-Length:") {
3509 return errors.New("no content-length")
3510 }
3511 if !strings.Contains(got, "Content-Type: some/type") {
3512 return errors.New("wrong content-type")
3513 }
3514 if strings.Contains(got, "Too-Late") {
3515 return errors.New("don't want too-late header")
3516 }
3517 return nil
3518 },
3519 },
3520 {
3521 name: "write then useless Header mutation",
3522 handler: func(rw ResponseWriter, r *Request) {
3523 rw.Write([]byte("hello world"))
3524 rw.Header().Set("Too-Late", "Write already wrote headers")
3525 },
3526 check: func(got, logs string) error {
3527 if strings.Contains(got, "Too-Late") {
3528 return errors.New("header appeared from after WriteHeader")
3529 }
3530 return nil
3531 },
3532 },
3533 {
3534 name: "flush then write",
3535 handler: func(rw ResponseWriter, r *Request) {
3536 rw.(Flusher).Flush()
3537 rw.Write([]byte("post-flush"))
3538 rw.Header().Set("Too-Late", "Write already wrote headers")
3539 },
3540 check: func(got, logs string) error {
3541 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3542 return errors.New("not chunked")
3543 }
3544 if strings.Contains(got, "Too-Late") {
3545 return errors.New("header appeared from after WriteHeader")
3546 }
3547 return nil
3548 },
3549 },
3550 {
3551 name: "header then flush",
3552 handler: func(rw ResponseWriter, r *Request) {
3553 rw.Header().Set("Content-Type", "some/type")
3554 rw.(Flusher).Flush()
3555 rw.Write([]byte("post-flush"))
3556 rw.Header().Set("Too-Late", "Write already wrote headers")
3557 },
3558 check: func(got, logs string) error {
3559 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3560 return errors.New("not chunked")
3561 }
3562 if strings.Contains(got, "Too-Late") {
3563 return errors.New("header appeared from after WriteHeader")
3564 }
3565 if !strings.Contains(got, "Content-Type: some/type") {
3566 return errors.New("wrong content-type")
3567 }
3568 return nil
3569 },
3570 },
3571 {
3572 name: "sniff-on-first-write content-type",
3573 handler: func(rw ResponseWriter, r *Request) {
3574 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3575 rw.Header().Set("Content-Type", "x/wrong")
3576 },
3577 check: func(got, logs string) error {
3578 if !strings.Contains(got, "Content-Type: text/html") {
3579 return errors.New("wrong content-type; want html")
3580 }
3581 return nil
3582 },
3583 },
3584 {
3585 name: "explicit content-type wins",
3586 handler: func(rw ResponseWriter, r *Request) {
3587 rw.Header().Set("Content-Type", "some/type")
3588 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3589 },
3590 check: func(got, logs string) error {
3591 if !strings.Contains(got, "Content-Type: some/type") {
3592 return errors.New("wrong content-type; want html")
3593 }
3594 return nil
3595 },
3596 },
3597 {
3598 name: "empty handler",
3599 handler: func(rw ResponseWriter, r *Request) {
3600 },
3601 check: func(got, logs string) error {
3602 if !strings.Contains(got, "Content-Length: 0") {
3603 return errors.New("want 0 content-length")
3604 }
3605 return nil
3606 },
3607 },
3608 {
3609 name: "only Header, no write",
3610 handler: func(rw ResponseWriter, r *Request) {
3611 rw.Header().Set("Some-Header", "some-value")
3612 },
3613 check: func(got, logs string) error {
3614 if !strings.Contains(got, "Some-Header") {
3615 return errors.New("didn't get header")
3616 }
3617 return nil
3618 },
3619 },
3620 {
3621 name: "WriteHeader call",
3622 handler: func(rw ResponseWriter, r *Request) {
3623 rw.WriteHeader(404)
3624 rw.Header().Set("Too-Late", "some-value")
3625 },
3626 check: func(got, logs string) error {
3627 if !strings.Contains(got, "404") {
3628 return errors.New("wrong status")
3629 }
3630 if strings.Contains(got, "Too-Late") {
3631 return errors.New("shouldn't have seen Too-Late")
3632 }
3633 return nil
3634 },
3635 },
3636 }
3637 for _, tc := range tests {
3638 ht := newHandlerTest(HandlerFunc(tc.handler))
3639 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3640 logs := ht.logbuf.String()
3641 if err := tc.check(got, logs); err != nil {
3642 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3643 }
3644 }
3645 }
3646
// errorListener is a net.Listener that never yields a connection. Each
// Accept call returns the next queued error in order; once the queue is
// empty, every Accept returns io.EOF.
type errorListener struct {
	errs []error
}

// Accept pops and returns the next queued error, or io.EOF when no errors
// remain. The returned net.Conn is always nil.
func (l *errorListener) Accept() (c net.Conn, err error) {
	if len(l.errs) > 0 {
		next := l.errs[0]
		l.errs = l.errs[1:]
		return nil, next
	}
	return nil, io.EOF
}
3659
3660 func (l *errorListener) Close() error {
3661 return nil
3662 }
3663
3664 func (l *errorListener) Addr() net.Addr {
3665 return dummyAddr("test-address")
3666 }
3667
3668 func TestAcceptMaxFds(t *testing.T) {
3669 setParallel(t)
3670
3671 ln := &errorListener{[]error{
3672 &net.OpError{
3673 Op: "accept",
3674 Err: syscall.EMFILE,
3675 }}}
3676 server := &Server{
3677 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3678 ErrorLog: log.New(io.Discard, "", 0),
3679 }
3680 err := server.Serve(ln)
3681 if err != io.EOF {
3682 t.Errorf("got error %v, want EOF", err)
3683 }
3684 }
3685
3686 func TestWriteAfterHijack(t *testing.T) {
3687 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3688 var buf strings.Builder
3689 wrotec := make(chan bool, 1)
3690 conn := &rwTestConn{
3691 Reader: bytes.NewReader(req),
3692 Writer: &buf,
3693 closec: make(chan bool, 1),
3694 }
3695 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3696 conn, bufrw, err := rw.(Hijacker).Hijack()
3697 if err != nil {
3698 t.Error(err)
3699 return
3700 }
3701 go func() {
3702 bufrw.Write([]byte("[hijack-to-bufw]"))
3703 bufrw.Flush()
3704 conn.Write([]byte("[hijack-to-conn]"))
3705 conn.Close()
3706 wrotec <- true
3707 }()
3708 })
3709 ln := &oneConnListener{conn: conn}
3710 go Serve(ln, handler)
3711 <-conn.closec
3712 <-wrotec
3713 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
3714 t.Errorf("wrote %q; want %q", g, w)
3715 }
3716 }
3717
3718 func TestDoubleHijack(t *testing.T) {
3719 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3720 var buf bytes.Buffer
3721 conn := &rwTestConn{
3722 Reader: bytes.NewReader(req),
3723 Writer: &buf,
3724 closec: make(chan bool, 1),
3725 }
3726 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3727 conn, _, err := rw.(Hijacker).Hijack()
3728 if err != nil {
3729 t.Error(err)
3730 return
3731 }
3732 _, _, err = rw.(Hijacker).Hijack()
3733 if err == nil {
3734 t.Errorf("got err = nil; want err != nil")
3735 }
3736 conn.Close()
3737 })
3738 ln := &oneConnListener{conn: conn}
3739 go Serve(ln, handler)
3740 <-conn.closec
3741 }
3742
3743
3744
3745
3746
3747
3748
3749 func TestHTTP10ConnectionHeader(t *testing.T) {
3750 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
3751 }
3752 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
3753 mux := NewServeMux()
3754 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
3755 ts := newClientServerTest(t, mode, mux).ts
3756
3757
3758 tests := []struct {
3759 req string
3760 expect []string
3761 }{
3762 {
3763 req: "GET / HTTP/1.0\r\n\r\n",
3764 expect: nil,
3765 },
3766 {
3767 req: "OPTIONS * HTTP/1.0\r\n\r\n",
3768 expect: nil,
3769 },
3770 {
3771 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
3772 expect: []string{"keep-alive"},
3773 },
3774 }
3775
3776 for _, tt := range tests {
3777 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3778 if err != nil {
3779 t.Fatal("dial err:", err)
3780 }
3781
3782 _, err = fmt.Fprint(conn, tt.req)
3783 if err != nil {
3784 t.Fatal("conn write err:", err)
3785 }
3786
3787 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
3788 if err != nil {
3789 t.Fatal("ReadResponse err:", err)
3790 }
3791 conn.Close()
3792 resp.Body.Close()
3793
3794 got := resp.Header["Connection"]
3795 if !reflect.DeepEqual(got, tt.expect) {
3796 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
3797 }
3798 }
3799 }
3800
3801
// TestServerReaderFromOrder checks that a handler can read its entire request
// body while another goroutine's io.Copy into the ResponseWriter is already
// blocked waiting for input, and that the response then carries exactly the
// bytes that copy produced.
func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }

// testServerReaderFromOrder is the per-mode body of TestServerReaderFromOrder.
// The handler's response is fed from a pipe that is written only after the
// full 3 MiB request body has been consumed.
func testServerReaderFromOrder(t *testing.T, mode testMode) {
	pr, pw := io.Pipe()
	const size = 3 << 20
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		rw.Header().Set("Content-Type", "text/plain")
		done := make(chan bool)
		// Start copying into the response first; this blocks on pr until
		// the handler writes to pw below.
		go func() {
			io.Copy(rw, pr)
			close(done)
		}()
		// Give the copy goroutine a chance to start and block before the
		// body is read (timing-based, not a guarantee).
		time.Sleep(25 * time.Millisecond)
		n, err := io.Copy(io.Discard, req.Body)
		if err != nil {
			t.Errorf("handler Copy: %v", err)
			return
		}
		if n != size {
			t.Errorf("handler Copy = %d; want %d", n, size)
		}
		// Unblock the response copy, then wait for it before returning.
		pw.Write([]byte("hi"))
		pw.Close()
		<-done
	}))

	req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
	if err != nil {
		t.Fatal(err)
	}
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	all, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if string(all) != "hi" {
		t.Errorf("Body = %q; want hi", all)
	}
}
3844
3845
3846 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
3847 for _, code := range []int{StatusNotModified, StatusNoContent} {
3848 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
3849 if r.URL.Path == "/header" {
3850 w.Header().Set("Content-Length", "123")
3851 }
3852 w.WriteHeader(code)
3853 if r.URL.Path == "/more" {
3854 w.Write([]byte("stuff"))
3855 }
3856 }))
3857 for _, req := range []string{
3858 "GET / HTTP/1.0",
3859 "GET /header HTTP/1.0",
3860 "GET /more HTTP/1.0",
3861 "GET / HTTP/1.1\nHost: foo",
3862 "GET /header HTTP/1.1\nHost: foo",
3863 "GET /more HTTP/1.1\nHost: foo",
3864 } {
3865 got := ht.rawResponse(req)
3866 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
3867 if !strings.Contains(got, wantStatus) {
3868 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
3869 } else if strings.Contains(got, "Content-Length") {
3870 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
3871 } else if strings.Contains(got, "stuff") {
3872 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
3873 }
3874 }
3875 }
3876 }
3877
3878 func TestContentTypeOkayOn204(t *testing.T) {
3879 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
3880 w.Header().Set("Content-Length", "123")
3881 w.Header().Set("Content-Type", "foo/bar")
3882 w.WriteHeader(204)
3883 }))
3884 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
3885 if !strings.Contains(got, "Content-Type: foo/bar") {
3886 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
3887 }
3888 if strings.Contains(got, "Content-Length: 123") {
3889 t.Errorf("Response = %q; don't want a Content-Length", got)
3890 }
3891 }
3892
3893
3894
3895
3896
3897
3898
// TestTransportAndServerSharedBodyRace runs a proxy that forwards its
// incoming request body, unchanged, as the body of an outbound request to a
// backend, then cancels that outbound request part-way through the response.
// The incoming body is thus shared between the server (proxy side) and the
// Transport (client side). NOTE(review): presumably intended to be run with
// the race detector to catch unsynchronized access to that shared body.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	const bodySize = 1 << 20

	// errorf reports a failure both on stdout (println) and through t.Error;
	// presumably so the message is visible right away even if the test
	// later hangs.
	errorf := func(format string, args ...any) {
		v := fmt.Sprintf(format, args...)
		println(v)
		t.Error(v)
	}

	// The backend handler blocks on unblockBackend until the test returns.
	unblockBackend := make(chan bool)
	backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		gone := rw.(CloseNotifier).CloseNotify()
		didCopy := make(chan any)
		// Echo up to bodySize bytes of the request body back as the response.
		go func() {
			n, err := io.CopyN(rw, req.Body, bodySize)
			didCopy <- []any{n, err}
		}()
		isGone := false
	Loop:
		// Wait for the echo to finish, noting whether CloseNotify fired and
		// printing a progress line for every second spent waiting.
		for {
			select {
			case <-didCopy:
				break Loop
			case <-gone:
				isGone = true
			case <-time.After(time.Second):
				println("1 second passes in backend, proxygone=", isGone)
			}
		}
		<-unblockBackend
	}))
	defer backend.close()

	// backendRespc hands the proxy's backend response to the test so the
	// test can close it.
	backendRespc := make(chan *Response, 1)
	var proxy *clientServerTest
	proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// Reuse the incoming request body as the outbound request body.
		req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
		req2.ContentLength = bodySize
		cancel := make(chan struct{})
		req2.Cancel = cancel

		bresp, err := proxy.c.Do(req2)
		if err != nil {
			errorf("Proxy outbound request: %v", err)
			return
		}
		// Read only half of the backend's echoed response.
		_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
		if err != nil {
			errorf("Proxy copy error: %v", err)
			return
		}
		backendRespc <- bresp

		// Cancel the outbound request mid-response: through the Cancel
		// channel in HTTP/2 mode, through Transport.CancelRequest otherwise.

		if mode == http2Mode {
			close(cancel)
		} else {
			proxy.c.Transport.(*Transport).CancelRequest(req2)
		}
		rw.Write([]byte("OK"))
	}))
	defer proxy.close()

	// Deferred last so it runs first: release the backend handler before
	// the servers are closed.
	defer close(unblockBackend)
	req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
	res, err := proxy.c.Do(req)
	if err != nil {
		t.Fatalf("Original request: %v", err)
	}

	res.Body.Close()
	select {
	case res := <-backendRespc:
		res.Body.Close()
	default:
		// No backend response was handed over (e.g. the proxy handler
		// returned early on an error).
	}
}
3986
3987
3988
3989
// TestRequestBodyCloseDoesntBlock checks that a Read on a request body that
// is still pending when the handler returns is released with an error rather
// than blocking forever. The client declares a 100000-byte body but never
// sends it. It runs only in http1Mode and is skipped in -short mode.
func TestRequestBodyCloseDoesntBlock(t *testing.T) {
	run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
}
func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}

	// readErrCh receives the result of the handler's background body Read;
	// errCh receives client-side dial/write failures.
	readErrCh := make(chan error, 1)
	errCh := make(chan error, 2)

	server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// Start a Read that cannot be satisfied (the client sends no body
		// bytes), then return from the handler while it is still pending.
		go func(body io.Reader) {
			_, err := body.Read(make([]byte, 100))
			readErrCh <- err
		}(req.Body)
		time.Sleep(500 * time.Millisecond)
	})).ts

	// closeConn keeps the client connection open until the test returns.
	closeConn := make(chan bool)
	defer close(closeConn)
	go func() {
		conn, err := net.Dial("tcp", server.Listener.Addr().String())
		if err != nil {
			errCh <- err
			return
		}
		defer conn.Close()
		_, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
		if err != nil {
			errCh <- err
			return
		}
		// Send headers only; hold the connection open without a body.
		<-closeConn
	}()
	select {
	case err := <-readErrCh:
		if err == nil {
			t.Error("Read was nil. Expected error.")
		}
	case err := <-errCh:
		t.Error(err)
	}
}
4036
4037
4038 func TestResponseWriterWriteString(t *testing.T) {
4039 okc := make(chan bool, 1)
4040 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4041 _, ok := w.(io.StringWriter)
4042 okc <- ok
4043 }))
4044 ht.rawResponse("GET / HTTP/1.0")
4045 select {
4046 case ok := <-okc:
4047 if !ok {
4048 t.Error("ResponseWriter did not implement io.StringWriter")
4049 }
4050 default:
4051 t.Error("handler was never called")
4052 }
4053 }
4054
4055 func TestAppendTime(t *testing.T) {
4056 var b [len(TimeFormat)]byte
4057 t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
4058 res := ExportAppendTime(b[:0], t1)
4059 t2, err := ParseTime(string(res))
4060 if err != nil {
4061 t.Fatalf("Error parsing time: %s", err)
4062 }
4063 if !t1.Equal(t2) {
4064 t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
4065 }
4066 }
4067
4068 func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
4069 func testServerConnState(t *testing.T, mode testMode) {
4070 handler := map[string]func(w ResponseWriter, r *Request){
4071 "/": func(w ResponseWriter, r *Request) {
4072 fmt.Fprintf(w, "Hello.")
4073 },
4074 "/close": func(w ResponseWriter, r *Request) {
4075 w.Header().Set("Connection", "close")
4076 fmt.Fprintf(w, "Hello.")
4077 },
4078 "/hijack": func(w ResponseWriter, r *Request) {
4079 c, _, _ := w.(Hijacker).Hijack()
4080 c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
4081 c.Close()
4082 },
4083 "/hijack-panic": func(w ResponseWriter, r *Request) {
4084 c, _, _ := w.(Hijacker).Hijack()
4085 c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
4086 c.Close()
4087 panic("intentional panic")
4088 },
4089 }
4090
4091
4092 type stateLog struct {
4093 active net.Conn
4094 got []ConnState
4095 want []ConnState
4096 complete chan<- struct{}
4097 }
4098 activeLog := make(chan *stateLog, 1)
4099
4100
4101
4102
4103 wantLog := func(doRequests func(), want ...ConnState) {
4104 t.Helper()
4105 complete := make(chan struct{})
4106 activeLog <- &stateLog{want: want, complete: complete}
4107
4108 doRequests()
4109
4110 <-complete
4111 sl := <-activeLog
4112 if !reflect.DeepEqual(sl.got, sl.want) {
4113 t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
4114 }
4115
4116
4117
4118 }
4119
4120 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4121 handler[r.URL.Path](w, r)
4122 }), func(ts *httptest.Server) {
4123 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
4124 ts.Config.ConnState = func(c net.Conn, state ConnState) {
4125 if c == nil {
4126 t.Errorf("nil conn seen in state %s", state)
4127 return
4128 }
4129 sl := <-activeLog
4130 if sl.active == nil && state == StateNew {
4131 sl.active = c
4132 } else if sl.active != c {
4133 t.Errorf("unexpected conn in state %s", state)
4134 activeLog <- sl
4135 return
4136 }
4137 sl.got = append(sl.got, state)
4138 if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
4139 close(sl.complete)
4140 sl.complete = nil
4141 }
4142 activeLog <- sl
4143 }
4144 }).ts
4145 defer func() {
4146 activeLog <- &stateLog{}
4147 ts.Close()
4148 }()
4149
4150 c := ts.Client()
4151
4152 mustGet := func(url string, headers ...string) {
4153 t.Helper()
4154 req, err := NewRequest("GET", url, nil)
4155 if err != nil {
4156 t.Fatal(err)
4157 }
4158 for len(headers) > 0 {
4159 req.Header.Add(headers[0], headers[1])
4160 headers = headers[2:]
4161 }
4162 res, err := c.Do(req)
4163 if err != nil {
4164 t.Errorf("Error fetching %s: %v", url, err)
4165 return
4166 }
4167 _, err = io.ReadAll(res.Body)
4168 defer res.Body.Close()
4169 if err != nil {
4170 t.Errorf("Error reading %s: %v", url, err)
4171 }
4172 }
4173
4174 wantLog(func() {
4175 mustGet(ts.URL + "/")
4176 mustGet(ts.URL + "/close")
4177 }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
4178
4179 wantLog(func() {
4180 mustGet(ts.URL + "/")
4181 mustGet(ts.URL+"/", "Connection", "close")
4182 }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
4183
4184 wantLog(func() {
4185 mustGet(ts.URL + "/hijack")
4186 }, StateNew, StateActive, StateHijacked)
4187
4188 wantLog(func() {
4189 mustGet(ts.URL + "/hijack-panic")
4190 }, StateNew, StateActive, StateHijacked)
4191
4192 wantLog(func() {
4193 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4194 if err != nil {
4195 t.Fatal(err)
4196 }
4197 c.Close()
4198 }, StateNew, StateClosed)
4199
4200 wantLog(func() {
4201 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4202 if err != nil {
4203 t.Fatal(err)
4204 }
4205 if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
4206 t.Fatal(err)
4207 }
4208 c.Read(make([]byte, 1))
4209 c.Close()
4210 }, StateNew, StateActive, StateClosed)
4211
4212 wantLog(func() {
4213 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4214 if err != nil {
4215 t.Fatal(err)
4216 }
4217 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4218 t.Fatal(err)
4219 }
4220 res, err := ReadResponse(bufio.NewReader(c), nil)
4221 if err != nil {
4222 t.Fatal(err)
4223 }
4224 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4225 t.Fatal(err)
4226 }
4227 c.Close()
4228 }, StateNew, StateActive, StateIdle, StateClosed)
4229 }
4230
4231 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4232 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4233 }
4234 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4235 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4236 }), func(ts *httptest.Server) {
4237 ts.Config.SetKeepAlivesEnabled(false)
4238 }).ts
4239 res, err := ts.Client().Get(ts.URL)
4240 if err != nil {
4241 t.Fatal(err)
4242 }
4243 defer res.Body.Close()
4244 if !res.Close {
4245 t.Errorf("Body.Close == false; want true")
4246 }
4247 }
4248
4249
4250 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4251 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4252 var n int32
4253 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4254 atomic.AddInt32(&n, 1)
4255 }), optQuietLog)
4256 var wg sync.WaitGroup
4257 const reqs = 20
4258 for i := 0; i < reqs; i++ {
4259 wg.Add(1)
4260 go func() {
4261 defer wg.Done()
4262 res, err := cst.c.Get(cst.ts.URL)
4263 if err != nil {
4264
4265
4266 time.Sleep(10 * time.Millisecond)
4267 res, err = cst.c.Get(cst.ts.URL)
4268 if err != nil {
4269 t.Error(err)
4270 return
4271 }
4272 }
4273 defer res.Body.Close()
4274 _, err = io.Copy(io.Discard, res.Body)
4275 if err != nil {
4276 t.Error(err)
4277 return
4278 }
4279 }()
4280 }
4281 wg.Wait()
4282 if got := atomic.LoadInt32(&n); got != reqs {
4283 t.Errorf("handler ran %d times; want %d", got, reqs)
4284 }
4285 }
4286
4287 func TestServerConnStateNew(t *testing.T) {
4288 sawNew := false
4289 srv := &Server{
4290 ConnState: func(c net.Conn, state ConnState) {
4291 if state == StateNew {
4292 sawNew = true
4293 }
4294 },
4295 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4296 }
4297 srv.Serve(&oneConnListener{
4298 conn: &rwTestConn{
4299 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4300 Writer: io.Discard,
4301 },
4302 })
4303 if !sawNew {
4304 t.Error("StateNew not seen")
4305 }
4306 }
4307
4308 type closeWriteTestConn struct {
4309 rwTestConn
4310 didCloseWrite bool
4311 }
4312
4313 func (c *closeWriteTestConn) CloseWrite() error {
4314 c.didCloseWrite = true
4315 return nil
4316 }
4317
4318 func TestCloseWrite(t *testing.T) {
4319 setParallel(t)
4320 var srv Server
4321 var testConn closeWriteTestConn
4322 c := ExportServerNewConn(&srv, &testConn)
4323 ExportCloseWriteAndWait(c)
4324 if !testConn.didCloseWrite {
4325 t.Error("didn't see CloseWrite call")
4326 }
4327 }
4328
4329
4330
4331
4332
4333
4334
4335
4336 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4337 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4338 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4339 io.WriteString(w, "Hello, ")
4340 w.(Flusher).Flush()
4341 conn, buf, _ := w.(Hijacker).Hijack()
4342 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4343 if err := buf.Flush(); err != nil {
4344 t.Error(err)
4345 }
4346 if err := conn.Close(); err != nil {
4347 t.Error(err)
4348 }
4349 })).ts
4350 res, err := Get(ts.URL)
4351 if err != nil {
4352 t.Fatal(err)
4353 }
4354 defer res.Body.Close()
4355 all, err := io.ReadAll(res.Body)
4356 if err != nil {
4357 t.Fatal(err)
4358 }
4359 if want := "Hello, world!"; string(all) != want {
4360 t.Errorf("Got %q; want %q", all, want)
4361 }
4362 }
4363
4364
4365
4366
4367
4368
4369
// TestServerKeepAliveAfterWriteError checks that a connection whose response
// write fails (the handler outlives the server's WriteTimeout) is not reused.
// Each of numReq sequential requests must fail and must arrive from a
// distinct client address. It runs only in http1Mode and is skipped in
// -short mode.
func TestServerKeepAliveAfterWriteError(t *testing.T) {
	run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
}
func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	const numReq = 3
	addrc := make(chan string, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addrc <- r.RemoteAddr
		// Sleep past the 250ms WriteTimeout set below, then try to write.
		time.Sleep(500 * time.Millisecond)
		w.(Flusher).Flush()
	}), func(ts *httptest.Server) {
		ts.Config.WriteTimeout = 250 * time.Millisecond
	}).ts

	// Issue the requests one after another; errc is closed when all are done.
	errc := make(chan error, numReq)
	go func() {
		defer close(errc)
		for i := 0; i < numReq; i++ {
			res, err := Get(ts.URL)
			if res != nil {
				res.Body.Close()
			}
			errc <- err
		}
	}()

	// Collect remote addresses seen by the handler and count client
	// successes until errc is closed.
	addrSeen := map[string]bool{}
	numOkay := 0
	for {
		select {
		case v := <-addrc:
			addrSeen[v] = true
		case err, ok := <-errc:
			if !ok {
				if len(addrSeen) != numReq {
					t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
				}
				if numOkay != 0 {
					t.Errorf("got %d successful client requests; want 0", numOkay)
				}
				return
			}
			if err == nil {
				numOkay++
			}
		}
	}
}
4421
4422
4423
4424 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4425 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4426 }
4427 func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
4428 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4429 w.Header().Set("Transfer-Encoding", "foo")
4430 io.WriteString(w, "<html>")
4431 })).ts
4432 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4433 if err != nil {
4434 t.Fatalf("Dial: %v", err)
4435 }
4436 defer c.Close()
4437 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4438 t.Fatal(err)
4439 }
4440 bs := bufio.NewScanner(c)
4441 var got strings.Builder
4442 for bs.Scan() {
4443 if strings.TrimSpace(bs.Text()) == "" {
4444 break
4445 }
4446 got.WriteString(bs.Text())
4447 got.WriteByte('\n')
4448 }
4449 if err := bs.Err(); err != nil {
4450 t.Fatal(err)
4451 }
4452 if strings.Contains(got.String(), "Content-Length") {
4453 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4454 }
4455 if strings.Contains(got.String(), "Content-Type") {
4456 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4457 }
4458 }
4459
4460
4461
4462 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4463 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4464 "\r\n\r\n" +
4465 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4466 var buf bytes.Buffer
4467 conn := &rwTestConn{
4468 Reader: bytes.NewReader(req),
4469 Writer: &buf,
4470 closec: make(chan bool, 1),
4471 }
4472 ln := &oneConnListener{conn: conn}
4473 numReq := 0
4474 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4475 numReq++
4476 }))
4477 <-conn.closec
4478 if numReq != 2 {
4479 t.Errorf("num requests = %d; want 2", numReq)
4480 t.Logf("Res: %s", buf.Bytes())
4481 }
4482 }
4483
// TestIssue13893_Expect100 checks that the server leaves the "Expect" header
// visible in r.Header for a request sent with "Expect: 100-continue".
func TestIssue13893_Expect100(t *testing.T) {
	// A single PUT with a 10-byte body. reqBytes is defined elsewhere in
	// this file; presumably it normalizes the literal's line endings.
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorld

`)
	var buf bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
		if _, ok := r.Header["Expect"]; !ok {
			t.Error("Expect header should not be filtered out")
		}
	}))
	// Wait until the server closes the connection.
	<-conn.closec
}
4510
// TestIssue11549_Expect100 sends three pipelined requests:
//   - a PUT with "Expect: 100-continue" whose handler reads the body;
//   - a second such PUT whose handler does not read its body;
//   - a trailing GET.
//
// It checks that only the first two reach the handler (the GET must be
// ignored) and that the server's output includes "Connection: close".
func TestIssue11549_Expect100(t *testing.T) {
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorldPUT /noreadbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

GET /should-be-ignored HTTP/1.1
Host: foo

`)
	var buf strings.Builder
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	numReq := 0
	go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
		numReq++
		// Only /readbody consumes its body; /noreadbody leaves it unread.
		if r.URL.Path == "/readbody" {
			io.ReadAll(r.Body)
		}
		io.WriteString(w, "Hello world!")
	}))
	// Wait until the server closes the connection.
	<-conn.closec
	if numReq != 2 {
		t.Errorf("num requests = %d; want 2", numReq)
	}
	if !strings.Contains(buf.String(), "Connection: close\r\n") {
		t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
	}
}
4553
4554
4555
4556 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4557 setParallel(t)
4558 conn := &testConn{closec: make(chan bool)}
4559 conn.readBuf.Write([]byte(fmt.Sprintf(
4560 "POST / HTTP/1.1\r\n" +
4561 "Host: test\r\n" +
4562 "Content-Length: 9999999999\r\n" +
4563 "\r\n" + strings.Repeat("a", 1<<20))))
4564
4565 ls := &oneConnListener{conn}
4566 var inHandlerLen int
4567 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4568 inHandlerLen = conn.readBuf.Len()
4569 rw.WriteHeader(404)
4570 }))
4571 <-conn.closec
4572 afterHandlerLen := conn.readBuf.Len()
4573
4574 if afterHandlerLen != inHandlerLen {
4575 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4576 }
4577 }
4578
4579 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4580 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4581 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4582 r.Body = nil
4583 fmt.Fprintf(w, "%v", r.RemoteAddr)
4584 }))
4585 get := func() string {
4586 res, err := cst.c.Get(cst.ts.URL)
4587 if err != nil {
4588 t.Fatal(err)
4589 }
4590 defer res.Body.Close()
4591 slurp, err := io.ReadAll(res.Body)
4592 if err != nil {
4593 t.Fatal(err)
4594 }
4595 return string(slurp)
4596 }
4597 a, b := get(), get()
4598 if a != b {
4599 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4600 }
4601 }
4602
4603
4604
4605 func TestServerValidatesHostHeader(t *testing.T) {
4606 tests := []struct {
4607 proto string
4608 host string
4609 want int
4610 }{
4611 {"HTTP/0.9", "", 505},
4612
4613 {"HTTP/1.1", "", 400},
4614 {"HTTP/1.1", "Host: \r\n", 200},
4615 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4616 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4617 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4618 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4619 {"HTTP/1.1", "Host: ::1\r\n", 200},
4620 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4621 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4622 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4623 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4624 {"HTTP/1.1", "Host: \x06\r\n", 400},
4625 {"HTTP/1.1", "Host: \xff\r\n", 400},
4626 {"HTTP/1.1", "Host: {\r\n", 400},
4627 {"HTTP/1.1", "Host: }\r\n", 400},
4628 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4629
4630
4631
4632 {"HTTP/1.0", "", 200},
4633 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4634 {"HTTP/1.0", "Host: \xff\r\n", 400},
4635
4636
4637 {"PRI * HTTP/2.0", "", 200},
4638
4639
4640 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4641
4642
4643 {"PRI / HTTP/2.0", "", 505},
4644 {"GET / HTTP/2.0", "", 505},
4645 {"GET / HTTP/3.0", "", 505},
4646 }
4647 for _, tt := range tests {
4648 conn := &testConn{closec: make(chan bool, 1)}
4649 methodTarget := "GET / "
4650 if !strings.HasPrefix(tt.proto, "HTTP/") {
4651 methodTarget = ""
4652 }
4653 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4654
4655 ln := &oneConnListener{conn}
4656 srv := Server{
4657 ErrorLog: quietLog,
4658 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4659 }
4660 go srv.Serve(ln)
4661 <-conn.closec
4662 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4663 if err != nil {
4664 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4665 continue
4666 }
4667 if res.StatusCode != tt.want {
4668 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4669 }
4670 }
4671 }
4672
4673 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4674 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
4675 }
// testServerHandlersCanHandleH2PRI writes "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
// on a raw TCP connection. It checks that an HTTP/1 handler:
//   - sees a PRI request for "*" with Request.Close set;
//   - can hijack the connection and read the trailing "SM\r\n\r\n";
//   - can write its own reply, which the client then reads in full.
func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
	const upgradeResponse = "upgrade here"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, br, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		if r.Method != "PRI" || r.RequestURI != "*" {
			t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
			return
		}
		// The test requires Close to be set on a PRI request, so the failure
		// message reports the observed value (false) and the wanted one (true).
		if !r.Close {
			t.Errorf("Request.Close = false; want true")
		}
		// The rest of the preface must still be readable from the
		// hijacked buffered reader.
		const want = "SM\r\n\r\n"
		buf := make([]byte, len(want))
		n, err := io.ReadFull(br, buf)
		if err != nil || string(buf[:n]) != want {
			t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
			return
		}
		io.WriteString(conn, upgradeResponse)
	})).ts

	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
	// The handler closes the hijacked conn when done, so ReadAll terminates.
	slurp, err := io.ReadAll(c)
	if err != nil {
		t.Fatal(err)
	}
	if string(slurp) != upgradeResponse {
		t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
	}
}
4716
4717
4718
// TestServerValidatesHeaders sends the server a GET request carrying one
// extra header line, over an in-memory connection, and checks the response
// status:
//   - well-formed headers get 200;
//   - malformed header names or values get 400;
//   - an oversized header gets 431.
func TestServerValidatesHeaders(t *testing.T) {
	setParallel(t)
	tests := []struct {
		header string
		want   int
	}{
		{"", 200},
		{"Foo: bar\r\n", 200},
		{"X-Foo: bar\r\n", 200},
		{"Foo: a space\r\n", 200},

		// Space, 0xff or NUL inside the header name; then a 2MB header.
		{"A space: foo\r\n", 400},
		{"foo\xffbar: foo\r\n", 400},
		{"foo\x00bar: foo\r\n", 400},
		{"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},

		// Whitespace between the header name and the colon.
		{"Foo : bar\r\n", 400},
		{"Foo\t: bar\r\n", 400},

		// Header values: space, tab and a 0xff byte are accepted; NUL and
		// DEL are rejected.
		{"foo: foo foo\r\n", 200},
		{"foo: foo\tfoo\r\n", 200},
		{"foo: foo\x00foo\r\n", 400},
		{"foo: foo\x7ffoo\r\n", 400},
		{"foo: foo\xfffoo\r\n", 200},
	}
	for _, tt := range tests {
		conn := &testConn{closec: make(chan bool, 1)}
		io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// res is nil on error; report the parse error itself.
			t.Errorf("For %q, ReadResponse: %v", tt.header, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
		}
	}
}
4766
// TestServerRequestContextCancel_ServeHTTPDone runs its test body in every default mode.
func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { run(t, testServerRequestContextCancel_ServeHTTPDone) }
// testServerRequestContextCancel_ServeHTTPDone checks the request context's
// lifetime around ServeHTTP. The context must still be live inside the
// handler. It must be done once the client has received the response.
func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
	// Buffered so the handler can hand its context back without blocking.
	ctxCh := make(chan context.Context, 1)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		reqCtx := r.Context()
		// Err is non-nil exactly when Done has been closed.
		if reqCtx.Err() != nil {
			t.Error("should not be Done in ServeHTTP")
		}
		ctxCh <- reqCtx
	})
	cst := newClientServerTest(t, mode, handler)

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	if reqCtx := <-ctxCh; reqCtx.Err() == nil {
		t.Error("context should be done after ServeHTTP completes")
	}
}
4793
4794
4795
4796
4797
// TestServerRequestContextCancel_ConnClose runs its test body in HTTP/1 mode only.
func TestServerRequestContextCancel_ConnClose(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerRequestContextCancel_ConnClose, modes)
}
// testServerRequestContextCancel_ConnClose sends one request over a raw TCP
// connection and closes the connection while the handler is still running.
// The handler blocks on its request context, so the test only finishes once
// the server cancels that context after the client hangs up.
func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
	entered := make(chan struct{})
	finished := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		close(entered)
		<-r.Context().Done()
		close(finished)
	})).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
	<-entered
	// Hang up mid-request, then wait for the handler to see cancellation.
	conn.Close()
	<-finished
}
4819
// TestServerContext_ServerContextKey runs its test body in every default mode.
func TestServerContext_ServerContextKey(t *testing.T) { run(t, testServerContext_ServerContextKey) }
// testServerContext_ServerContextKey checks that a handler's request context
// carries a *Server under ServerContextKey.
func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		v := r.Context().Value(ServerContextKey)
		if _, isServer := v.(*Server); !isServer {
			t.Errorf("context value = %T; want *http.Server", v)
		}
	})
	cst := newClientServerTest(t, mode, handler)

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4837
// TestServerContext_LocalAddrContextKey runs its test body in every default mode.
func TestServerContext_LocalAddrContextKey(t *testing.T) { run(t, testServerContext_LocalAddrContextKey) }
// testServerContext_LocalAddrContextKey checks that a handler's request
// context carries, under LocalAddrContextKey, a net.Addr that equals the test
// server's listening address.
func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
	ch := make(chan any, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ch <- r.Context().Value(LocalAddrContextKey)
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// net/http requires callers to close every response body, HEAD included.
	res.Body.Close()

	host := cst.ts.Listener.Addr().String()
	got := <-ch
	if addr, ok := got.(net.Addr); !ok {
		t.Errorf("local addr value = %T; want net.Addr", got)
	} else if fmt.Sprint(addr) != host {
		t.Errorf("local addr = %v; want %v", addr, host)
	}
}
4858
4859
// TestHandlerSetTransferEncodingChunked checks the raw response when a
// handler itself sets "Transfer-Encoding: chunked". The header line must
// appear exactly once.
func TestHandlerSetTransferEncodingChunked(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	const wantHdr = "Transfer-Encoding: chunked"
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "chunked")
		w.Write([]byte("hello"))
	})
	raw := newHandlerTest(h).rawResponse("GET / HTTP/1.1\nHost: foo")
	if got := strings.Count(raw, wantHdr); got != 1 {
		t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", wantHdr, got, raw)
	}
}
4873
4874
// TestHandlerSetTransferEncodingGzip checks the raw response when a handler
// sets "Transfer-Encoding: gzip" and writes gzip data. Both the gzip and the
// chunked Transfer-Encoding header lines must appear exactly once each.
func TestHandlerSetTransferEncodingGzip(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "gzip")
		zw := gzip.NewWriter(w)
		zw.Write([]byte("hello"))
		zw.Close()
	})
	raw := newHandlerTest(h).rawResponse("GET / HTTP/1.1\nHost: foo")

	wantHdrs := []string{"Transfer-Encoding: gzip", "Transfer-Encoding: chunked"}
	for _, hdr := range wantHdrs {
		if got := strings.Count(raw, hdr); got != 1 {
			t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, got, raw)
		}
	}
}
4892
// BenchmarkClientServer runs benchmarkClientServer in http1Mode, https1Mode
// and http2Mode.
func BenchmarkClientServer(b *testing.B) {
	modes := []testMode{http1Mode, https1Mode, http2Mode}
	run(b, benchmarkClientServer, modes)
}
// benchmarkClientServer measures sequential GET round trips from the test
// client to a server whose handler replies "Hello world.\n". Server setup is
// excluded from the timing.
func benchmarkClientServer(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})).ts
	b.StartTimer()

	client := ts.Client()
	for n := 0; n < b.N; n++ {
		res, err := client.Get(ts.URL)
		if err != nil {
			b.Fatal("Get:", err)
		}
		all, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatal("ReadAll:", err)
		}
		if body := string(all); body != "Hello world.\n" {
			b.Fatal("Got body:", body)
		}
	}

	b.StopTimer()
}
4923
// BenchmarkClientServerParallel runs benchmarkClientServerParallel at
// parallelism 4 and 64. Each runs in http1Mode, https1Mode and http2Mode.
func BenchmarkClientServerParallel(b *testing.B) {
	for _, p := range []int{4, 64} {
		bench := func(b *testing.B, mode testMode) {
			benchmarkClientServerParallel(b, p, mode)
		}
		b.Run(fmt.Sprint(p), func(b *testing.B) {
			run(b, bench, []testMode{http1Mode, https1Mode, http2Mode})
		})
	}
}
4933
// benchmarkClientServerParallel issues GETs from b.RunParallel goroutines,
// with the given parallelism, against a "Hello world.\n" server. Transport
// and read errors are logged and the iteration is skipped. A wrong body
// panics.
func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
	b.ReportAllocs()
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})).ts
	b.ResetTimer()
	b.SetParallelism(parallelism)
	b.RunParallel(func(pb *testing.PB) {
		client := ts.Client()
		for pb.Next() {
			res, err := client.Get(ts.URL)
			if err != nil {
				b.Logf("Get: %v", err)
				continue
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				b.Logf("ReadAll: %v", err)
				continue
			}
			if body := string(all); body != "Hello world.\n" {
				panic("Got body: " + body)
			}
		}
	})
}
4962
4963
4964
4965
4966
4967
4968
4969
4970
4971
// BenchmarkServer measures server throughput with the client running in a
// separate process. The benchmark starts an httptest server and re-executes
// the test binary with TEST_BENCH_SERVER_URL and TEST_BENCH_CLIENT_N set. The
// child takes the branch below: it issues that many GETs and then exits.
func BenchmarkServer(b *testing.B) {
	b.ReportAllocs()

	// Child process: act as the client, then exit without benchmarking.
	if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
		n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
		if err != nil {
			panic(err)
		}
		for i := 0; i < n; i++ {
			res, err := Get(url)
			if err != nil {
				log.Panicf("Get: %v", err)
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				log.Panicf("ReadAll: %v", err)
			}
			body := string(all)
			if body != "Hello world.\n" {
				log.Panicf("Got body: %q", body)
			}
		}
		os.Exit(0)
	}

	var res = []byte("Hello world.\n")
	b.StopTimer()
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	}))
	defer ts.Close()
	b.StartTimer()

	cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer$")
	// os/exec uses only the last value of a duplicated environment key, so
	// the benchmark's variables go after the inherited environment. That way
	// a stale inherited value cannot override them.
	cmd.Env = append(os.Environ(),
		fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
		fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
	)
	out, err := cmd.CombinedOutput()
	if err != nil {
		b.Errorf("Test failure: %v, with output: %s", err, out)
	}
}
5018
5019
// getNoBody issues a GET for urlStr with the default client. It closes the
// response body without reading it and returns the response, whose body is
// already closed. On a request error it returns (nil, err).
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5028
5029
5030
// BenchmarkClient measures sequential client GETs against a server running in
// a child process. The benchmark re-executes the test binary with
// TEST_BENCH_SERVER=yes. In that child, the branch below serves requests
// instead of benchmarking.
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Child process. Steps:
		//   - listen on localhost at TEST_BENCH_SERVER_PORT ("0", i.e. an
		//     OS-chosen port, when unset);
		//   - print the address on stdout for the parent;
		//   - serve data until a request carries a non-empty "stop" form
		//     value, then exit the process.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		fmt.Println(ln.Addr().String())
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent process: start the child server. The child is tied to ctx via
	// testenv.CommandContext. The deferred func below cancels ctx, which is
	// expected to stop the child if the benchmark exits early.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$")
	cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// Reap the child in the background. done is closed after the result is
	// sent. That keeps the deferred receive from blocking even when the
	// final receive at the end of the benchmark has already consumed the
	// result.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()
	defer func() {
		cancel()
		<-done
	}()

	// The child's first line of output is its listening address.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	// Confirm the child is serving before starting the timer.
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit. It calls os.Exit inside the handler before
	// writing a response, so this request is expected to fail. Its result is
	// deliberately ignored.
	getNoBody(url + "?stop=yes")
	if err := <-done; err != nil {
		b.Fatalf("subprocess failed: %v", err)
	}
}
5119
// BenchmarkServerFakeConnNoKeepAlive measures Serve handling one HTTP/1.0
// request per connection. Each iteration refills an in-memory testConn and
// hands it back to the listener.
func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.0
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	res := []byte("Hello world!\n")

	conn := &testConn{closec: make(chan bool, 1)}
	ln := &oneConnListener{}
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	})
	for n := b.N; n > 0; n-- {
		// Start each round with empty buffers and a fresh copy of the request.
		conn.readBuf.Reset()
		conn.writeBuf.Reset()
		conn.readBuf.Write(req)
		ln.conn = conn
		Serve(ln, handler)
		// Wait for the close signal on closec before reusing conn.
		<-conn.closec
	}
}
5151
5152
5153 type repeatReader struct {
5154 content []byte
5155 count int
5156 off int
5157 }
5158
5159 func (r *repeatReader) Read(p []byte) (n int, err error) {
5160 if r.count <= 0 {
5161 return 0, io.EOF
5162 }
5163 n = copy(p, r.content[r.off:])
5164 r.off += n
5165 if r.off == len(r.content) {
5166 r.count--
5167 r.off = 0
5168 }
5169 return
5170 }
5171
// BenchmarkServerFakeConnWithKeepAlive measures the server handling b.N
// browser-like HTTP/1.1 requests. The requests are pipelined on a single
// in-memory keep-alive connection whose output is discarded.
func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	res := []byte("Hello world!\n")

	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	handled := 0
	go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		h := rw.Header()
		h.Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	}))
	// The connection is closed once the request stream is exhausted.
	<-conn.closec
	if handled != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5203
5204
5205
// BenchmarkServerFakeConnWithKeepAliveLite is the minimal variant of the
// keep-alive fake-connection benchmark. Requests carry only a Host header,
// and the handler sets no response headers.
func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	res := []byte("Hello world!\n")

	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	handled := 0
	go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		rw.Write(res)
	}))
	<-conn.closec
	if handled != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5231
5232 const someResponse = "<html>some response</html>"
5233
5234
5235 var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
5236
5237
// BenchmarkServerHandlerTypeLen benchmarks a handler that sets both
// Content-Type and Content-Length before writing response.
func BenchmarkServerHandlerTypeLen(b *testing.B) {
	benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Type", "text/html")
		hdr.Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	}))
}
5245
5246
// BenchmarkServerHandlerNoLen benchmarks a handler that sets Content-Type
// but leaves Content-Length to the server.
func BenchmarkServerHandlerNoLen(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5253
5254
// BenchmarkServerHandlerNoType benchmarks a handler that sets Content-Length
// but leaves Content-Type to the server.
func BenchmarkServerHandlerNoType(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5261
5262
// BenchmarkServerHandlerNoHeader benchmarks a handler that sets no headers
// and only writes response.
func BenchmarkServerHandlerNoHeader(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, _ *Request) { w.Write(response) })
	benchmarkHandler(b, h)
}
5268
// benchmarkHandler serves b.N minimal GET requests through h. The requests
// are pipelined on one in-memory keep-alive connection, and the output is
// discarded. It reports an error if h did not run exactly b.N times.
func benchmarkHandler(b *testing.B, h Handler) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	handled := 0
	go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		h.ServeHTTP(rw, r)
	}))
	<-conn.closec
	if handled != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5291
// BenchmarkServerHijack measures serving one request whose handler hijacks
// the connection and closes it immediately. The same in-memory connection
// is reloaded and served again on every iteration.
func BenchmarkServerHijack(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			panic(err)
		}
		c.Close()
	})
	conn := &rwTestConn{
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	for n := b.N; n > 0; n-- {
		// Reload the request and re-arm the listener for this round.
		conn.Reader = bytes.NewReader(req)
		ln.conn = conn
		Serve(ln, h)
		<-conn.closec
	}
}
5316
// BenchmarkCloseNotifier runs benchmarkCloseNotifier in HTTP/1 mode only.
func BenchmarkCloseNotifier(b *testing.B) {
	run(b, benchmarkCloseNotifier, []testMode{http1Mode})
}
// benchmarkCloseNotifier measures CloseNotify delivery. Each iteration:
//   - dials the server and sends one keep-alive request;
//   - closes the connection;
//   - waits for the handler, blocked on CloseNotify, to report that it saw
//     the close.
func benchmarkCloseNotifier(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	notified := make(chan bool)
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		<-rw.(CloseNotifier).CloseNotify()
		notified <- true
	})).ts
	b.StartTimer()
	for n := 0; n < b.N; n++ {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		if _, err := fmt.Fprintf(c, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			b.Fatal(err)
		}
		c.Close()
		<-notified
	}
	b.StopTimer()
}
5341
5342
5343 func TestConcurrentServerServe(t *testing.T) {
5344 setParallel(t)
5345 for i := 0; i < 100; i++ {
5346 ln1 := &oneConnListener{conn: nil}
5347 ln2 := &oneConnListener{conn: nil}
5348 srv := Server{}
5349 go func() { srv.Serve(ln1) }()
5350 go func() { srv.Serve(ln2) }()
5351 }
5352 }
5353
// TestServerIdleTimeout runs testServerIdleTimeout in HTTP/1 mode only.
func TestServerIdleTimeout(t *testing.T) {
	run(t, testServerIdleTimeout, []testMode{http1Mode})
}
// testServerIdleTimeout checks IdleTimeout and ReadHeaderTimeout on HTTP/1
// connections:
//   - two back-to-back requests share one connection;
//   - a request made after sleeping past IdleTimeout arrives on a new
//     connection;
//   - a raw connection that sends an unterminated request header gets no
//     bytes back before the read fails.
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.Copy(io.Discard, r.Body)
		io.WriteString(w, r.RemoteAddr)
	}), func(ts *httptest.Server) {
		ts.Config.ReadHeaderTimeout = 1 * time.Second
		ts.Config.IdleTimeout = 2 * time.Second
	}).ts
	c := ts.Client()

	// The handler echoes the client's address, so equal replies mean the
	// requests used the same connection. This reuses the package-level get
	// helper; the local name only shadows it after this statement.
	get := func() string { return get(t, c, ts.URL) }

	a1, a2 := get(), get()
	if a1 != a2 {
		t.Fatalf("did requests on different connections")
	}
	// Sleep past IdleTimeout so the idle connection is dropped.
	time.Sleep(3 * time.Second)
	a3 := get()
	if a2 == a3 {
		t.Fatal("request three unexpectedly on same connection")
	}

	// Send a request header without the terminating blank line. Reading
	// even one byte must fail, so the server must not answer.
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
	time.Sleep(2 * time.Second)
	if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
		t.Fatal("copy byte succeeded; want err")
	}
}
5403
// get fetches url with c and returns the response body as a string. A
// request or read error ends the test via t.Fatal.
func get(t *testing.T, c *Client, url string) string {
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	body, readErr := io.ReadAll(res.Body)
	res.Body.Close()
	if readErr != nil {
		t.Fatal(readErr)
	}
	return string(body)
}
5416
5417
5418
// TestServerSetKeepAlivesEnabledClosesConns runs its test body in HTTP/1 mode only.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerSetKeepAlivesEnabledClosesConns, modes)
}
// testServerSetKeepAlivesEnabledClosesConns makes two requests, which should
// share one connection that then sits in the client's idle pool. It then
// disables keep-alives on the server and waits until the client's idle pool
// is empty.
func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// The handler echoes the client address, so identical replies mean the
	// two requests shared a connection.
	first, second := get(t, c, ts.URL), get(t, c, ts.URL)
	if first != second {
		t.Errorf("server reported requests from %q and %q; expected same connection", first, second)
	} else {
		t.Logf("made two requests from a single conn %q (as expected)", first)
	}

	if idle := tr.IdleConnStrsForTesting(); len(idle) != 1 {
		t.Errorf("found %d idle conns (%q); want 1", len(idle), idle)
	}

	ts.Config.SetKeepAlivesEnabled(false)

	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		idle := tr.IdleConnStrsForTesting()
		if len(idle) == 0 {
			return true
		}
		if d > 0 {
			t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, idle)
		}
		return false
	})
}
5464
// TestServerShutdown runs testServerShutdown in every default mode.
func TestServerShutdown(t *testing.T) {
	run(t, testServerShutdown)
}
// testServerShutdown checks graceful Shutdown while a request is in flight.
// The first request to reach the handler does the following:
//   - snapshots connection states;
//   - starts Shutdown in a goroutine and waits for the RegisterOnShutdown
//     hook to fire;
//   - checks that further requests cannot get through;
//   - only then writes its own response.
// The client side then expects Shutdown to succeed and the snapshot to show
// exactly one active connection.
func testServerShutdown(t *testing.T, mode testMode) {
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		once.Do(func() {
			// Capture states while this request is active, then begin
			// Shutdown concurrently. Shutdown is expected to wait for this
			// active request before it returns.
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until Shutdown has run the OnShutdown hook. Per the
			// messages below, the listener should already be closed by then.
			<-gotOnShutdown

			// Keep issuing requests until one fails, which is the expected
			// outcome. In HTTP/2 mode an unexpected success is logged and
			// retried (go.dev/issue/59038). In other modes a success fails
			// the test, which also ends the loop.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	// The snapshot was taken inside the handler, so exactly one connection
	// should have been active.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5525
// TestServerShutdownStateNew runs testServerShutdownStateNew in every default mode.
func TestServerShutdownStateNew(t *testing.T) {
	run(t, testServerShutdownStateNew)
}
// testServerShutdownStateNew checks Shutdown while a client has connected
// (so the server sees StateNew) but never sent a request. The test expects:
//   - Shutdown returns nil after about expectTimeout (5s);
//   - it takes no less than half that and no more than 1.5x;
//   - the client's pending Read then fails because the server closed the
//     connection.
func testServerShutdownStateNew(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	var connAccepted sync.WaitGroup
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Never reached: the client below sends no request.
	}), func(ts *httptest.Server) {
		ts.Config.ConnState = func(conn net.Conn, state ConnState) {
			if state == StateNew {
				connAccepted.Done()
			}
		}
	}).ts

	// Dial without sending anything, then wait until the server has
	// reported the connection as StateNew.
	connAccepted.Add(1)
	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	connAccepted.Wait()

	shutdownRes := make(chan error, 1)
	go func() {
		shutdownRes <- ts.Config.Shutdown(context.Background())
	}()
	readRes := make(chan error, 1)
	go func() {
		_, err := c.Read([]byte{0})
		readRes <- err
	}()

	const expectTimeout = 5 * time.Second

	t0 := time.Now()
	select {
	case got := <-shutdownRes:
		d := time.Since(t0)
		if got != nil {
			// Report the Shutdown error itself, not the (nil) dial error.
			t.Fatalf("shutdown error after %v: %v", d, got)
		}
		if d < expectTimeout/2 {
			t.Errorf("shutdown too soon after %v", d)
		}
	case <-time.After(expectTimeout * 3 / 2):
		t.Fatalf("timeout waiting for shutdown")
	}

	// Once Shutdown has closed the idle StateNew connection, the client's
	// blocked Read must return an error.
	if err := <-readRes; err == nil {
		t.Error("expected error from Read")
	}
}
5592
5593
5594 func TestServerCloseDeadlock(t *testing.T) {
5595 var s Server
5596 s.Close()
5597 s.Close()
5598 }
5599
5600
5601
// TestServerKeepAlivesEnabled runs testServerKeepAlivesEnabled in every
// default mode, without t.Parallel.
func TestServerKeepAlivesEnabled(t *testing.T) {
	run(t, testServerKeepAlivesEnabled, testNotParallel)
}
// testServerKeepAlivesEnabled disables keep-alives on the server. It then
// checks, twice, that a request is served on a fresh connection. Each round:
//   - waits until the server reports all connections idle;
//   - sends a GET traced with httptrace;
//   - requires exactly one GotConn whose connection was neither reused nor
//     previously idle.
func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
	if mode == http2Mode {
		// Shorten the HTTP/2 GOAWAY timeout for this test and defer the
		// restore func it returns.
		defer ExportSetH2GoawayTimeout(10 * time.Millisecond)()
	}

	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer cst.close()
	srv := cst.ts.Config
	srv.SetKeepAlivesEnabled(false)

	for try := 0; try < 2; try++ {
		waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
			idle := srv.ExportAllConnsIdle()
			if !idle && d > 0 {
				t.Logf("test server still has active conns after %v", d)
			}
			return idle
		})

		var (
			conns int
			info  httptrace.GotConnInfo
		)
		trace := &httptrace.ClientTrace{
			GotConn: func(v httptrace.GotConnInfo) {
				conns++
				info = v
			},
		}
		req, err := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()

		switch {
		case conns != 1:
			t.Fatalf("request %v: got %v conns, want 1", try, conns)
		case info.Reused || info.WasIdle:
			t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
		}
	}
}
5648
5649
5650
5651
// TestServerCancelsReadTimeoutWhenIdle runs its test body in every default mode.
func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
	run(t, testServerCancelsReadTimeoutWhenIdle)
}
// testServerCancelsReadTimeoutWhenIdle sets Server.ReadTimeout to timeout
// and uses a handler that deliberately takes 2*timeout before writing "ok".
// The request must still complete with "ok" rather than with the request
// context's error. Each timeout is tried in increasing order until one
// passes.
func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
	timeouts := []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}
	runTimeSensitiveTest(t, timeouts, func(t *testing.T, timeout time.Duration) error {
		handler := HandlerFunc(func(w ResponseWriter, r *Request) {
			select {
			case <-time.After(2 * timeout):
				fmt.Fprint(w, "ok")
			case <-r.Context().Done():
				fmt.Fprint(w, r.Context().Err())
			}
		})
		ts := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
			ts.Config.ReadTimeout = timeout
		}).ts

		res, err := ts.Client().Get(ts.URL)
		if err != nil {
			return fmt.Errorf("Get: %v", err)
		}
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			return fmt.Errorf("Body ReadAll: %v", err)
		}
		if string(slurp) != "ok" {
			return fmt.Errorf("got: %q, want ok", slurp)
		}
		return nil
	})
}
5689
5690
5691
5692
// TestServerCancelsReadHeaderTimeoutWhenIdle runs its test body in HTTP/1 mode only.
func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerCancelsReadHeaderTimeoutWhenIdle, modes)
}
// testServerCancelsReadHeaderTimeoutWhenIdle checks that ReadHeaderTimeout
// does not cut off an idle keep-alive connection. It sets ReadHeaderTimeout
// to timeout and IdleTimeout to 0, then:
//   - completes one request on a raw connection;
//   - stays idle for 1.5*timeout;
//   - requires a second request on the same connection to succeed.
// Each timeout is tried in increasing order until one passes.
func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = timeout
			ts.Config.IdleTimeout = 0
		}).ts

		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("dial failed: %v", err)
		}
		br := bufio.NewReader(conn)
		defer conn.Close()

		// exchange writes one request on conn and reads its response. nth
		// and when only label the error messages.
		exchange := func(nth, when string) error {
			if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
				return fmt.Errorf("writing %s request failed: %v", nth, err)
			}
			if _, err := ReadResponse(br, nil); err != nil {
				return fmt.Errorf("%s response (%s timeout) failed: %v", nth, when, err)
			}
			return nil
		}

		if err := exchange("first", "before"); err != nil {
			return err
		}
		time.Sleep(timeout * 3 / 2)
		return exchange("second", "after")
	})
}
5741
5742
5743
// runTimeSensitiveTest calls test once per entry in durations, in order, and
// stops at the first call that returns nil. Callers list durations in
// increasing order, so a failure is logged and the next, more generous
// duration is tried. The test fails with the last error if every duration
// fails. It also fails right away if an attempt fails after t has already
// been marked failed, since a retry cannot clear that state.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	for i, d := range durations {
		err := test(t, d)
		if err == nil {
			return
		}
		if i == len(durations)-1 || t.Failed() {
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		// Keep a record of earlier failures instead of discarding them.
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
5755
5756
5757
// TestServerDuplicateBackgroundRead runs its test body in HTTP/1 mode only.
func TestServerDuplicateBackgroundRead(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerDuplicateBackgroundRead, modes)
}
// testServerDuplicateBackgroundRead opens several raw connections, 5 (or 3
// in short mode). On each it writes 2000 (or 100) identical GET requests
// back to back without waiting for responses. A separate goroutine per
// connection drains whatever the server sends. The test fails on any dial
// or write error.
// NOTE(review): the name suggests it guards against the server starting a
// second background read on a connection; confirm against server.go.
func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
	if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
		testenv.SkipFlaky(t, 24826)
	}

	goroutines := 5
	requests := 2000
	if testing.Short() {
		goroutines = 3
		requests = 100
	}

	hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts

	reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")

	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			cn, err := net.Dial("tcp", hts.Listener.Addr().String())
			if err != nil {
				t.Error(err)
				return
			}
			defer cn.Close()

			// Adding here is safe even while wg.Wait may be running: this
			// goroutine's own deferred Done has not run yet, so the counter
			// is still positive. The drain goroutine ends once the deferred
			// cn.Close above makes its read fail.
			wg.Add(1)
			go func() {
				defer wg.Done()
				io.Copy(io.Discard, cn)
			}()

			for j := 0; j < requests; j++ {
				if t.Failed() {
					return
				}
				_, err := cn.Write(reqBytes)
				if err != nil {
					t.Error(err)
					return
				}
			}
		}()
	}
	wg.Wait()
}
5809
5810
5811
5812
5813
5814
// TestServerHijackGetsBackgroundByte runs its test body in HTTP/1 mode only.
func TestServerHijackGetsBackgroundByte(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerHijackGetsBackgroundByte, modes)
}
// testServerHijackGetsBackgroundByte sends a request and, once the handler
// is running, writes "foo" and half-closes the connection. The handler then
// hijacks the connection and checks two things:
//   - the three bytes "foo" are readable from the returned bufio.Reader;
//   - its request context has not been canceled.
// Skipped on plan9 (golang.org/issue/18657).
func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	inHandler := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		// Signal the client before hijacking, so "foo" is sent while the
		// server still owns the connection.
		inHandler <- true

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		peek, err := buf.Reader.Peek(3)
		if string(peek) != "foo" || err != nil {
			t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
		}

		select {
		case <-r.Context().Done():
			t.Error("context unexpectedly canceled")
		default:
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	<-inHandler
	if _, err := cn.Write([]byte("foo")); err != nil {
		t.Fatal(err)
	}

	// Half-close: the server sees EOF after "foo", but the client can still
	// read.
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	<-done
}
5867
5868
5869
5870
// TestServerHijackGetsBackgroundByte_big runs its test body in HTTP/1 mode only.
func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerHijackGetsBackgroundByte_big, modes)
}
// testServerHijackGetsBackgroundByte_big sends a request immediately
// followed by size (8KB) 'x' bytes, then half-closes the connection. It
// checks that the handler, after hijacking, reads exactly those size bytes,
// all 'x', from the returned bufio.Reader. Skipped on plan9
// (golang.org/issue/18657).
func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	const size = 8 << 10
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		slurp, err := io.ReadAll(buf.Reader)
		if err != nil {
			t.Errorf("ReadAll: %v", err)
		}
		// Trimming every leading 'x' leaves nothing iff all bytes are 'x'.
		allX := len(bytes.TrimLeft(slurp, "x")) == 0
		if len(slurp) != size {
			t.Errorf("read %d; want %d", len(slurp), size)
		} else if !allX {
			t.Errorf("read %q; want %d 'x'", slurp, size)
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
		strings.Repeat("x", size)); err != nil {
		t.Fatal(err)
	}
	// Half-close so the handler's ReadAll sees EOF after the payload.
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}

	<-done
}
5921
5922
5923 func TestServerValidatesMethod(t *testing.T) {
5924 tests := []struct {
5925 method string
5926 want int
5927 }{
5928 {"GET", 200},
5929 {"GE(T", 400},
5930 }
5931 for _, tt := range tests {
5932 conn := &testConn{closec: make(chan bool, 1)}
5933 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
5934
5935 ln := &oneConnListener{conn}
5936 go Serve(ln, serve(200))
5937 <-conn.closec
5938 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5939 if err != nil {
5940 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
5941 continue
5942 }
5943 if res.StatusCode != tt.want {
5944 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
5945 }
5946 }
5947 }
5948
5949
// eofListenerNotComparable is a net.Listener whose underlying type is a
// slice, so values of it cannot be compared with ==. Accept always fails
// with io.EOF, Addr returns nil, and Close is a no-op.
type eofListenerNotComparable []int

func (l eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

func (l eofListenerNotComparable) Addr() net.Addr {
	return nil
}

func (l eofListenerNotComparable) Close() error {
	return nil
}
5955
5956
5957 func TestServerListenNotComparableListener(t *testing.T) {
5958 var s Server
5959 s.Serve(make(eofListenerNotComparable, 1))
5960 }
5961
5962
// countCloseListener wraps a net.Listener and counts calls to Close. Only the
// first call is forwarded to the wrapped Listener (when one is set); every
// later call just returns nil.
type countCloseListener struct {
	net.Listener
	closes int32 // updated with sync/atomic
}

func (cl *countCloseListener) Close() error {
	first := atomic.AddInt32(&cl.closes, 1) == 1
	if !first || cl.Listener == nil {
		return nil
	}
	return cl.Listener.Close()
}
5975
5976
5977 func TestServerCloseListenerOnce(t *testing.T) {
5978 setParallel(t)
5979 defer afterTest(t)
5980
5981 ln := newLocalListener(t)
5982 defer ln.Close()
5983
5984 cl := &countCloseListener{Listener: ln}
5985 server := &Server{}
5986 sdone := make(chan bool, 1)
5987
5988 go func() {
5989 server.Serve(cl)
5990 sdone <- true
5991 }()
5992 time.Sleep(10 * time.Millisecond)
5993 server.Shutdown(context.Background())
5994 ln.Close()
5995 <-sdone
5996
5997 nclose := atomic.LoadInt32(&cl.closes)
5998 if nclose != 1 {
5999 t.Errorf("Close calls = %v; want 1", nclose)
6000 }
6001 }
6002
6003
6004 func TestServerShutdownThenServe(t *testing.T) {
6005 var srv Server
6006 cl := &countCloseListener{Listener: nil}
6007 srv.Shutdown(context.Background())
6008 got := srv.Serve(cl)
6009 if got != ErrServerClosed {
6010 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6011 }
6012 nclose := atomic.LoadInt32(&cl.closes)
6013 if nclose != 1 {
6014 t.Errorf("Close calls = %v; want 1", nclose)
6015 }
6016 }
6017
6018
// TestStripPortFromHost checks that ServeMux ignores the port in the request
// host when matching host-specific patterns. A request to example.com:9000
// should reach the "example.com/" handler, not "example.com:9000/".
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "OK")
	})
	mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "uh-oh!")
	})

	req := httptest.NewRequest("GET", "http://example.com:9000/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6039
6040 func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
6041 func testServerContexts(t *testing.T, mode testMode) {
6042 type baseKey struct{}
6043 type connKey struct{}
6044 ch := make(chan context.Context, 1)
6045 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6046 ch <- r.Context()
6047 }), func(ts *httptest.Server) {
6048 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6049 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6050 t.Errorf("unexpected onceClose listener type %T", ln)
6051 }
6052 return context.WithValue(context.Background(), baseKey{}, "base")
6053 }
6054 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6055 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6056 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6057 }
6058 return context.WithValue(ctx, connKey{}, "conn")
6059 }
6060 }).ts
6061 res, err := ts.Client().Get(ts.URL)
6062 if err != nil {
6063 t.Fatal(err)
6064 }
6065 res.Body.Close()
6066 ctx := <-ch
6067 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6068 t.Errorf("base context key = %#v; want %q", got, want)
6069 }
6070 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6071 t.Errorf("conn context key = %#v; want %q", got, want)
6072 }
6073 }
6074
6075
6076 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6077 run(t, testConnContextNotModifyingAllContexts)
6078 }
6079 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6080 type connKey struct{}
6081 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6082 rw.Header().Set("Connection", "close")
6083 }), func(ts *httptest.Server) {
6084 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6085 if got := ctx.Value(connKey{}); got != nil {
6086 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6087 }
6088 return context.WithValue(ctx, connKey{}, "conn")
6089 }
6090 }).ts
6091
6092 var res *Response
6093 var err error
6094
6095 res, err = ts.Client().Get(ts.URL)
6096 if err != nil {
6097 t.Fatal(err)
6098 }
6099 res.Body.Close()
6100
6101 res, err = ts.Client().Get(ts.URL)
6102 if err != nil {
6103 t.Fatal(err)
6104 }
6105 res.Body.Close()
6106 }
6107
6108
6109
6110 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6111 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6112 }
6113 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6114 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6115 w.Write([]byte("Hello, World!"))
6116 })).ts
6117
6118 serverURL, err := url.Parse(cst.URL)
6119 if err != nil {
6120 t.Fatalf("Failed to parse server URL: %v", err)
6121 }
6122
6123 unsupportedTEs := []string{
6124 "fugazi",
6125 "foo-bar",
6126 "unknown",
6127 `" chunked"`,
6128 }
6129
6130 for _, badTE := range unsupportedTEs {
6131 http1ReqBody := fmt.Sprintf(""+
6132 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6133 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6134
6135 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6136 if err != nil {
6137 t.Errorf("%q. unexpected error: %v", badTE, err)
6138 continue
6139 }
6140
6141 wantBody := fmt.Sprintf("" +
6142 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6143 "Connection: close\r\n\r\nUnsupported transfer encoding")
6144
6145 if string(gotBody) != wantBody {
6146 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6147 }
6148 }
6149 }
6150
6151
6152 func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
6153 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6154 type setting struct {
6155 name string
6156 body []byte
6157
6158
6159
6160
6161 contentEncoding any
6162 wantContentType string
6163 }
6164
6165 settings := []*setting{
6166 {
6167 name: "gzip content-encoding, gzipped",
6168 contentEncoding: "application/gzip",
6169 wantContentType: "",
6170 body: func() []byte {
6171 buf := new(bytes.Buffer)
6172 gzw := gzip.NewWriter(buf)
6173 gzw.Write([]byte("doctype html><p>Hello</p>"))
6174 gzw.Close()
6175 return buf.Bytes()
6176 }(),
6177 },
6178 {
6179 name: "zlib content-encoding, zlibbed",
6180 contentEncoding: "application/zlib",
6181 wantContentType: "",
6182 body: func() []byte {
6183 buf := new(bytes.Buffer)
6184 zw := zlib.NewWriter(buf)
6185 zw.Write([]byte("doctype html><p>Hello</p>"))
6186 zw.Close()
6187 return buf.Bytes()
6188 }(),
6189 },
6190 {
6191 name: "no content-encoding",
6192 wantContentType: "application/x-gzip",
6193 body: func() []byte {
6194 buf := new(bytes.Buffer)
6195 gzw := gzip.NewWriter(buf)
6196 gzw.Write([]byte("doctype html><p>Hello</p>"))
6197 gzw.Close()
6198 return buf.Bytes()
6199 }(),
6200 },
6201 {
6202 name: "phony content-encoding",
6203 contentEncoding: "foo/bar",
6204 body: []byte("doctype html><p>Hello</p>"),
6205 },
6206 {
6207 name: "empty but set content-encoding",
6208 contentEncoding: "",
6209 wantContentType: "audio/mpeg",
6210 body: []byte("ID3"),
6211 },
6212 }
6213
6214 for _, tt := range settings {
6215 t.Run(tt.name, func(t *testing.T) {
6216 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6217 if tt.contentEncoding != nil {
6218 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6219 }
6220 rw.Write(tt.body)
6221 }))
6222
6223 res, err := cst.c.Get(cst.ts.URL)
6224 if err != nil {
6225 t.Fatalf("Failed to fetch URL: %v", err)
6226 }
6227 defer res.Body.Close()
6228
6229 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6230 if w != nil {
6231 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6232 } else if g != "" {
6233 t.Errorf("Unexpected Content-Encoding %q", g)
6234 }
6235 }
6236
6237 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6238 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6239 }
6240 })
6241 }
6242 }
6243
6244
6245
// TestTimeoutHandlerSuperfluousLogs checks what is logged when a handler
// wrapped by TimeoutHandler calls WriteHeader more than once. The server
// should log one "superfluous response.WriteHeader call" line per extra
// call, naming the handler closure and its exact file:line. This must hold
// both when the handler returns before the timeout and when it outlives it.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Record this function's name and source file. The expected log lines
	// name the handler closure relative to them.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// After writing headers, the handler blocks on exitHandler. In the
			// non-timeout case a value is queued up front, so it returns at
			// once. Otherwise the deferred close releases it when this subtest
			// function returns.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			// The first WriteHeader sets the status, and each of the next
			// three is superfluous and should produce one log line. The
			// runtime.Caller line must come immediately after the last
			// WriteHeader, because the expected line numbers below are
			// computed as offsets from it. Do not insert lines between these
			// calls.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Use a generous timeout so the handler reliably returns
				// before TimeoutHandler gives up.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Remove the Date and Content-Type headers before dumping, so the
			// dump can be compared with the fixed wantResp.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Expect exactly one log line per superfluous WriteHeader call.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// Despite its name, lastSpuriousLine is the runtime.Caller line.
			// The superfluous WriteHeader calls are the three lines before it.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Each entry must name the handler closure and its exact line. The
			// handler is nested inside the subtest closure, hence func<N>.<M>.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6350
6351
6352
6353
// fetchWireResponse dials host over TCP, writes http1ReqBody verbatim, and
// returns every byte the peer sends back until it closes the connection.
// The read only ends at EOF, so the request should make the peer hang up
// (the caller in this file sends "Connection: close").
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, err := net.Dial("tcp", host)
	if err != nil {
		return nil, err
	}
	defer c.Close()

	if _, err = c.Write(http1ReqBody); err != nil {
		return nil, err
	}
	return io.ReadAll(c)
}
6366
6367 func BenchmarkResponseStatusLine(b *testing.B) {
6368 b.ReportAllocs()
6369 b.RunParallel(func(pb *testing.PB) {
6370 bw := bufio.NewWriter(io.Discard)
6371 var buf3 [3]byte
6372 for pb.Next() {
6373 Export_writeStatusLine(bw, true, 200, buf3[:])
6374 }
6375 })
6376 }
6377
6378 func TestDisableKeepAliveUpgrade(t *testing.T) {
6379 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6380 }
6381 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6382 if testing.Short() {
6383 t.Skip("skipping in short mode")
6384 }
6385
6386 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6387 w.Header().Set("Connection", "Upgrade")
6388 w.Header().Set("Upgrade", "someProto")
6389 w.WriteHeader(StatusSwitchingProtocols)
6390 c, buf, err := w.(Hijacker).Hijack()
6391 if err != nil {
6392 return
6393 }
6394 defer c.Close()
6395
6396
6397
6398 io.Copy(c, buf)
6399 }), func(ts *httptest.Server) {
6400 ts.Config.SetKeepAlivesEnabled(false)
6401 }).ts
6402
6403 cl := s.Client()
6404 cl.Transport.(*Transport).DisableKeepAlives = true
6405
6406 resp, err := cl.Get(s.URL)
6407 if err != nil {
6408 t.Fatalf("failed to perform request: %v", err)
6409 }
6410 defer resp.Body.Close()
6411
6412 if resp.StatusCode != StatusSwitchingProtocols {
6413 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6414 }
6415
6416 rwc, ok := resp.Body.(io.ReadWriteCloser)
6417 if !ok {
6418 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6419 }
6420
6421 _, err = rwc.Write([]byte("hello"))
6422 if err != nil {
6423 t.Fatalf("failed to write to body: %v", err)
6424 }
6425
6426 b := make([]byte, 5)
6427 _, err = io.ReadFull(rwc, b)
6428 if err != nil {
6429 t.Fatalf("failed to read from body: %v", err)
6430 }
6431
6432 if string(b) != "hello" {
6433 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6434 }
6435 }
6436
// tlogWriter is an io.Writer that forwards each Write to t.Log. It is used to
// route output such as a server's error log into the test log. Write never
// fails and always reports the whole buffer as written.
type tlogWriter struct{ t *testing.T }

func (lw tlogWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	lw.t.Log(string(p))
	return n, nil
}
6443
6444 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6445 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6446 }
6447 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6448 const wantBody = "want"
6449 const wantUpgrade = "someProto"
6450 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6451 w.Header().Set("Connection", "Upgrade")
6452 w.Header().Set("Upgrade", wantUpgrade)
6453 w.WriteHeader(StatusSwitchingProtocols)
6454 NewResponseController(w).Flush()
6455
6456
6457 w.WriteHeader(200)
6458 if _, err := w.Write([]byte("x")); err == nil {
6459 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6460 }
6461
6462 c, _, err := NewResponseController(w).Hijack()
6463 if err != nil {
6464 t.Errorf("Hijack: %v", err)
6465 return
6466 }
6467 defer c.Close()
6468 if _, err := c.Write([]byte(wantBody)); err != nil {
6469 t.Errorf("Write to hijacked body: %v", err)
6470 }
6471 }), func(ts *httptest.Server) {
6472
6473 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6474 }).ts
6475
6476 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6477 if err != nil {
6478 t.Fatalf("net.Dial: %v", err)
6479 }
6480 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6481 if err != nil {
6482 t.Fatalf("conn.Write: %v", err)
6483 }
6484 defer conn.Close()
6485
6486 r := bufio.NewReader(conn)
6487 res, err := ReadResponse(r, &Request{Method: "GET"})
6488 if err != nil {
6489 t.Fatal("ReadResponse error:", err)
6490 }
6491 if res.StatusCode != StatusSwitchingProtocols {
6492 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6493 }
6494 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6495 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6496 }
6497 body, err := io.ReadAll(r)
6498 if err != nil {
6499 t.Error(err)
6500 }
6501 if string(body) != wantBody {
6502 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6503 }
6504 }
6505
6506 func TestMuxRedirectRelative(t *testing.T) {
6507 setParallel(t)
6508 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6509 if err != nil {
6510 t.Errorf("%s", err)
6511 }
6512 mux := NewServeMux()
6513 resp := httptest.NewRecorder()
6514 mux.ServeHTTP(resp, req)
6515 if got, want := resp.Header().Get("Location"), "/"; got != want {
6516 t.Errorf("Location header expected %q; got %q", want, got)
6517 }
6518 if got, want := resp.Code, StatusMovedPermanently; got != want {
6519 t.Errorf("Expected response code %d; got %d", want, got)
6520 }
6521 }
6522
6523
6524 func TestQuerySemicolon(t *testing.T) {
6525 t.Cleanup(func() { afterTest(t) })
6526
6527 tests := []struct {
6528 query string
6529 xNoSemicolons string
6530 xWithSemicolons string
6531 expectParseFormErr bool
6532 }{
6533 {"?a=1;x=bad&x=good", "good", "bad", true},
6534 {"?a=1;b=bad&x=good", "good", "good", true},
6535 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
6536 {"?a=1;x=good;x=bad", "", "good", true},
6537 }
6538
6539 run(t, func(t *testing.T, mode testMode) {
6540 for _, tt := range tests {
6541 t.Run(tt.query+"/allow=false", func(t *testing.T) {
6542 allowSemicolons := false
6543 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
6544 })
6545 t.Run(tt.query+"/allow=true", func(t *testing.T) {
6546 allowSemicolons, expectParseFormErr := true, false
6547 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
6548 })
6549 }
6550 })
6551 }
6552
6553 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
6554 writeBackX := func(w ResponseWriter, r *Request) {
6555 x := r.URL.Query().Get("x")
6556 if expectParseFormErr {
6557 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6558 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6559 }
6560 } else {
6561 if err := r.ParseForm(); err != nil {
6562 t.Errorf("expected no error from ParseForm, got %v", err)
6563 }
6564 }
6565 if got := r.FormValue("x"); x != got {
6566 t.Errorf("got %q from FormValue, want %q", got, x)
6567 }
6568 fmt.Fprintf(w, "%s", x)
6569 }
6570
6571 h := Handler(HandlerFunc(writeBackX))
6572 if allowSemicolons {
6573 h = AllowQuerySemicolons(h)
6574 }
6575
6576 logBuf := &strings.Builder{}
6577 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
6578 ts.Config.ErrorLog = log.New(logBuf, "", 0)
6579 }).ts
6580
6581 req, _ := NewRequest("GET", ts.URL+query, nil)
6582 res, err := ts.Client().Do(req)
6583 if err != nil {
6584 t.Fatal(err)
6585 }
6586 slurp, _ := io.ReadAll(res.Body)
6587 res.Body.Close()
6588 if got, want := res.StatusCode, 200; got != want {
6589 t.Errorf("Status = %d; want = %d", got, want)
6590 }
6591 if got, want := string(slurp), wantX; got != want {
6592 t.Errorf("Body = %q; want = %q", got, want)
6593 }
6594 }
6595
6596 func TestMaxBytesHandler(t *testing.T) {
6597 setParallel(t)
6598 defer afterTest(t)
6599
6600 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
6601 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
6602 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
6603 func(t *testing.T) {
6604 run(t, func(t *testing.T, mode testMode) {
6605 testMaxBytesHandler(t, mode, maxSize, requestSize)
6606 })
6607 })
6608 }
6609 }
6610 }
6611
// testMaxBytesHandler POSTs a requestSize-byte body to MaxBytesHandler(echo,
// maxSize) and checks four things:
//   - the wrapped handler never reads more than maxSize bytes;
//   - the handler sees a read error when requestSize exceeds maxSize;
//   - the handler sees no error, and reads the whole body, otherwise;
//   - the client gets back exactly the bytes the handler read.
func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
	// Set by the handler; read only after the response has been consumed.
	var (
		handlerN   int64
		handlerErr error
	)
	// echo buffers the (size-limited) request body, records how much it read
	// and any error, then writes what it read back to the client.
	echo := HandlerFunc(func(w ResponseWriter, r *Request) {
		var buf bytes.Buffer
		handlerN, handlerErr = io.Copy(&buf, r.Body)
		io.Copy(w, &buf)
	})

	ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts
	defer ts.Close()

	c := ts.Client()

	body := strings.Repeat("a", int(requestSize))
	// Every request body handed out below is tracked in wg, so the test does
	// not return while the transport still holds one. This defer is
	// registered after ts.Close, so wg.Wait runs first.
	var wg sync.WaitGroup
	defer wg.Wait()
	// getBody returns a fresh body on each call. It is also installed as
	// req.GetBody, so the client can obtain a new copy if it needs to resend.
	// wgReadCloser is defined elsewhere and presumably calls wg.Done on
	// Close; confirm.
	getBody := func() (io.ReadCloser, error) {
		wg.Add(1)
		body := &wgReadCloser{
			Reader: strings.NewReader(body),
			wg:     &wg,
		}
		return body, nil
	}
	reqBody, _ := getBody()
	req, err := NewRequest("POST", ts.URL, reqBody)
	if err != nil {
		reqBody.Close()
		t.Fatal(err)
	}
	// NewRequest cannot infer a length from this body type, so set it here.
	req.ContentLength = int64(len(body))
	req.GetBody = getBody
	req.Header.Set("Content-Type", "text/plain")

	var buf strings.Builder
	res, err := c.Do(req)
	if err != nil {
		t.Errorf("unexpected connection error: %v", err)
	} else {
		_, err = io.Copy(&buf, res.Body)
		res.Body.Close()
		if err != nil {
			t.Errorf("unexpected read error: %v", err)
		}
	}
	if handlerN > maxSize {
		t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
	}
	if requestSize > maxSize && handlerErr == nil {
		t.Error("expected error on handler side; got nil")
	}
	if requestSize <= maxSize {
		if handlerErr != nil {
			t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
		}
		if handlerN != requestSize {
			t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
		}
	}
	if buf.Len() != int(handlerN) {
		t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
	}
}
6678
6679 func TestEarlyHints(t *testing.T) {
6680 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6681 h := w.Header()
6682 h.Add("Link", "</style.css>; rel=preload; as=style")
6683 h.Add("Link", "</script.js>; rel=preload; as=script")
6684 w.WriteHeader(StatusEarlyHints)
6685
6686 h.Add("Link", "</foo.js>; rel=preload; as=script")
6687 w.WriteHeader(StatusEarlyHints)
6688
6689 w.Write([]byte("stuff"))
6690 }))
6691
6692 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
6693 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
6694 if !strings.Contains(got, expected) {
6695 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
6696 }
6697 }
6698 func TestProcessing(t *testing.T) {
6699 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6700 w.WriteHeader(StatusProcessing)
6701 w.Write([]byte("stuff"))
6702 }))
6703
6704 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
6705 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
6706 if !strings.Contains(got, expected) {
6707 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
6708 }
6709 }
6710
6711 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
6712 func testParseFormCleanup(t *testing.T, mode testMode) {
6713 if mode == http2Mode {
6714 t.Skip("https://go.dev/issue/20253")
6715 }
6716
6717 const maxMemory = 1024
6718 const key = "file"
6719
6720 if runtime.GOOS == "windows" {
6721
6722 t.Skip("https://go.dev/issue/25965")
6723 }
6724
6725 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6726 r.ParseMultipartForm(maxMemory)
6727 f, _, err := r.FormFile(key)
6728 if err != nil {
6729 t.Errorf("r.FormFile(%q) = %v", key, err)
6730 return
6731 }
6732 of, ok := f.(*os.File)
6733 if !ok {
6734 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
6735 return
6736 }
6737 w.Write([]byte(of.Name()))
6738 }))
6739
6740 fBuf := new(bytes.Buffer)
6741 mw := multipart.NewWriter(fBuf)
6742 mf, err := mw.CreateFormFile(key, "myfile.txt")
6743 if err != nil {
6744 t.Fatal(err)
6745 }
6746 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
6747 t.Fatal(err)
6748 }
6749 if err := mw.Close(); err != nil {
6750 t.Fatal(err)
6751 }
6752 req, err := NewRequest("POST", cst.ts.URL, fBuf)
6753 if err != nil {
6754 t.Fatal(err)
6755 }
6756 req.Header.Set("Content-Type", mw.FormDataContentType())
6757 res, err := cst.c.Do(req)
6758 if err != nil {
6759 t.Fatal(err)
6760 }
6761 defer res.Body.Close()
6762 fname, err := io.ReadAll(res.Body)
6763 if err != nil {
6764 t.Fatal(err)
6765 }
6766 cst.close()
6767 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
6768 t.Errorf("file %q exists after HTTP handler returned", string(fname))
6769 }
6770 }
6771
6772 func TestHeadBody(t *testing.T) {
6773 const identityMode = false
6774 const chunkedMode = true
6775 run(t, func(t *testing.T, mode testMode) {
6776 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
6777 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
6778 })
6779 }
6780
6781 func TestGetBody(t *testing.T) {
6782 const identityMode = false
6783 const chunkedMode = true
6784 run(t, func(t *testing.T, mode testMode) {
6785 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
6786 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
6787 })
6788 }
6789
6790 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
6791 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6792 b, err := io.ReadAll(r.Body)
6793 if err != nil {
6794 t.Errorf("server reading body: %v", err)
6795 return
6796 }
6797 w.Header().Set("X-Request-Body", string(b))
6798 w.Header().Set("Content-Length", "0")
6799 }))
6800 defer cst.close()
6801 for _, reqBody := range []string{
6802 "",
6803 "",
6804 "request_body",
6805 "",
6806 } {
6807 var bodyReader io.Reader
6808 if reqBody != "" {
6809 bodyReader = strings.NewReader(reqBody)
6810 if chunked {
6811 bodyReader = bufio.NewReader(bodyReader)
6812 }
6813 }
6814 req, err := NewRequest(method, cst.ts.URL, bodyReader)
6815 if err != nil {
6816 t.Fatal(err)
6817 }
6818 res, err := cst.c.Do(req)
6819 if err != nil {
6820 t.Fatal(err)
6821 }
6822 res.Body.Close()
6823 if got, want := res.StatusCode, 200; got != want {
6824 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
6825 }
6826 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
6827 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
6828 }
6829 }
6830 }
6831
6832
6833
6834 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
6835 func testDisableContentLength(t *testing.T, mode testMode) {
6836 if mode == http2Mode {
6837 t.Skip("skipping until h2_bundle.go is updated; see https://go-review.googlesource.com/c/net/+/471535")
6838 }
6839
6840 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6841 w.Header()["Content-Length"] = nil
6842 fmt.Fprintf(w, "OK")
6843 }))
6844
6845 res, err := noCL.c.Get(noCL.ts.URL)
6846 if err != nil {
6847 t.Fatal(err)
6848 }
6849 if got, haveCL := res.Header["Content-Length"]; haveCL {
6850 t.Errorf("Unexpected Content-Length: %q", got)
6851 }
6852 if err := res.Body.Close(); err != nil {
6853 t.Fatal(err)
6854 }
6855
6856 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6857 fmt.Fprintf(w, "OK")
6858 }))
6859
6860 res, err = withCL.c.Get(withCL.ts.URL)
6861 if err != nil {
6862 t.Fatal(err)
6863 }
6864 if got := res.Header.Get("Content-Length"); got != "2" {
6865 t.Errorf("Content-Length: %q; want 2", got)
6866 }
6867 if err := res.Body.Close(); err != nil {
6868 t.Fatal(err)
6869 }
6870 }
6871
View as plain text