Source file src/net/http/transfer_test.go

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"crypto/rand"
    11  	"fmt"
    12  	"io"
    13  	"os"
    14  	"reflect"
    15  	"strings"
    16  	"testing"
    17  )
    18  
    19  func TestBodyReadBadTrailer(t *testing.T) {
    20  	b := &body{
    21  		src: strings.NewReader("foobar"),
    22  		hdr: true, // force reading the trailer
    23  		r:   bufio.NewReader(strings.NewReader("")),
    24  	}
    25  	buf := make([]byte, 7)
    26  	n, err := b.Read(buf[:3])
    27  	got := string(buf[:n])
    28  	if got != "foo" || err != nil {
    29  		t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err)
    30  	}
    31  
    32  	n, err = b.Read(buf[:])
    33  	got = string(buf[:n])
    34  	if got != "bar" || err != nil {
    35  		t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err)
    36  	}
    37  
    38  	n, err = b.Read(buf[:])
    39  	got = string(buf[:n])
    40  	if err == nil {
    41  		t.Errorf("final Read was successful (%q), expected error from trailer read", got)
    42  	}
    43  }
    44  
    45  func TestFinalChunkedBodyReadEOF(t *testing.T) {
    46  	res, err := ReadResponse(bufio.NewReader(strings.NewReader(
    47  		"HTTP/1.1 200 OK\r\n"+
    48  			"Transfer-Encoding: chunked\r\n"+
    49  			"\r\n"+
    50  			"0a\r\n"+
    51  			"Body here\n\r\n"+
    52  			"09\r\n"+
    53  			"continued\r\n"+
    54  			"0\r\n"+
    55  			"\r\n")), nil)
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  	want := "Body here\ncontinued"
    60  	buf := make([]byte, len(want))
    61  	n, err := res.Body.Read(buf)
    62  	if n != len(want) || err != io.EOF {
    63  		t.Logf("body = %#v", res.Body)
    64  		t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want))
    65  	}
    66  	if string(buf) != want {
    67  		t.Errorf("buf = %q; want %q", buf, want)
    68  	}
    69  }
    70  
    71  func TestDetectInMemoryReaders(t *testing.T) {
    72  	pr, _ := io.Pipe()
    73  	tests := []struct {
    74  		r    io.Reader
    75  		want bool
    76  	}{
    77  		{pr, false},
    78  
    79  		{bytes.NewReader(nil), true},
    80  		{bytes.NewBuffer(nil), true},
    81  		{strings.NewReader(""), true},
    82  
    83  		{io.NopCloser(pr), false},
    84  
    85  		{io.NopCloser(bytes.NewReader(nil)), true},
    86  		{io.NopCloser(bytes.NewBuffer(nil)), true},
    87  		{io.NopCloser(strings.NewReader("")), true},
    88  	}
    89  	for i, tt := range tests {
    90  		got := isKnownInMemoryReader(tt.r)
    91  		if got != tt.want {
    92  			t.Errorf("%d: got = %v; want %v", i, got, tt.want)
    93  		}
    94  	}
    95  }
    96  
    97  type mockTransferWriter struct {
    98  	CalledReader io.Reader
    99  	WriteCalled  bool
   100  }
   101  
   102  var _ io.ReaderFrom = (*mockTransferWriter)(nil)
   103  
   104  func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) {
   105  	w.CalledReader = r
   106  	return io.Copy(io.Discard, r)
   107  }
   108  
   109  func (w *mockTransferWriter) Write(p []byte) (int, error) {
   110  	w.WriteCalled = true
   111  	return io.Discard.Write(p)
   112  }
   113  
   114  func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
   115  	fileType := reflect.TypeFor[*os.File]()
   116  	bufferType := reflect.TypeFor[*bytes.Buffer]()
   117  
   118  	nBytes := int64(1 << 10)
   119  	newFileFunc := func() (r io.Reader, done func(), err error) {
   120  		f, err := os.CreateTemp("", "net-http-newfilefunc")
   121  		if err != nil {
   122  			return nil, nil, err
   123  		}
   124  
   125  		// Write some bytes to the file to enable reading.
   126  		if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
   127  			return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
   128  		}
   129  		if _, err := f.Seek(0, 0); err != nil {
   130  			return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
   131  		}
   132  
   133  		done = func() {
   134  			f.Close()
   135  			os.Remove(f.Name())
   136  		}
   137  
   138  		return f, done, nil
   139  	}
   140  
   141  	newBufferFunc := func() (io.Reader, func(), error) {
   142  		return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
   143  	}
   144  
   145  	cases := []struct {
   146  		name             string
   147  		bodyFunc         func() (io.Reader, func(), error)
   148  		method           string
   149  		contentLength    int64
   150  		transferEncoding []string
   151  		limitedReader    bool
   152  		expectedReader   reflect.Type
   153  		expectedWrite    bool
   154  	}{
   155  		{
   156  			name:           "file, non-chunked, size set",
   157  			bodyFunc:       newFileFunc,
   158  			method:         "PUT",
   159  			contentLength:  nBytes,
   160  			limitedReader:  true,
   161  			expectedReader: fileType,
   162  		},
   163  		{
   164  			name:   "file, non-chunked, size set, nopCloser wrapped",
   165  			method: "PUT",
   166  			bodyFunc: func() (io.Reader, func(), error) {
   167  				r, cleanup, err := newFileFunc()
   168  				return io.NopCloser(r), cleanup, err
   169  			},
   170  			contentLength:  nBytes,
   171  			limitedReader:  true,
   172  			expectedReader: fileType,
   173  		},
   174  		{
   175  			name:           "file, non-chunked, negative size",
   176  			method:         "PUT",
   177  			bodyFunc:       newFileFunc,
   178  			contentLength:  -1,
   179  			expectedReader: fileType,
   180  		},
   181  		{
   182  			name:           "file, non-chunked, CONNECT, negative size",
   183  			method:         "CONNECT",
   184  			bodyFunc:       newFileFunc,
   185  			contentLength:  -1,
   186  			expectedReader: fileType,
   187  		},
   188  		{
   189  			name:             "file, chunked",
   190  			method:           "PUT",
   191  			bodyFunc:         newFileFunc,
   192  			transferEncoding: []string{"chunked"},
   193  			expectedWrite:    true,
   194  		},
   195  		{
   196  			name:           "buffer, non-chunked, size set",
   197  			bodyFunc:       newBufferFunc,
   198  			method:         "PUT",
   199  			contentLength:  nBytes,
   200  			limitedReader:  true,
   201  			expectedReader: bufferType,
   202  		},
   203  		{
   204  			name:   "buffer, non-chunked, size set, nopCloser wrapped",
   205  			method: "PUT",
   206  			bodyFunc: func() (io.Reader, func(), error) {
   207  				r, cleanup, err := newBufferFunc()
   208  				return io.NopCloser(r), cleanup, err
   209  			},
   210  			contentLength:  nBytes,
   211  			limitedReader:  true,
   212  			expectedReader: bufferType,
   213  		},
   214  		{
   215  			name:          "buffer, non-chunked, negative size",
   216  			method:        "PUT",
   217  			bodyFunc:      newBufferFunc,
   218  			contentLength: -1,
   219  			expectedWrite: true,
   220  		},
   221  		{
   222  			name:          "buffer, non-chunked, CONNECT, negative size",
   223  			method:        "CONNECT",
   224  			bodyFunc:      newBufferFunc,
   225  			contentLength: -1,
   226  			expectedWrite: true,
   227  		},
   228  		{
   229  			name:             "buffer, chunked",
   230  			method:           "PUT",
   231  			bodyFunc:         newBufferFunc,
   232  			transferEncoding: []string{"chunked"},
   233  			expectedWrite:    true,
   234  		},
   235  	}
   236  
   237  	for _, tc := range cases {
   238  		t.Run(tc.name, func(t *testing.T) {
   239  			body, cleanup, err := tc.bodyFunc()
   240  			if err != nil {
   241  				t.Fatal(err)
   242  			}
   243  			defer cleanup()
   244  
   245  			mw := &mockTransferWriter{}
   246  			tw := &transferWriter{
   247  				Body:             body,
   248  				ContentLength:    tc.contentLength,
   249  				TransferEncoding: tc.transferEncoding,
   250  			}
   251  
   252  			if err := tw.writeBody(mw); err != nil {
   253  				t.Fatal(err)
   254  			}
   255  
   256  			if tc.expectedReader != nil {
   257  				if mw.CalledReader == nil {
   258  					t.Fatal("did not call ReadFrom")
   259  				}
   260  
   261  				var actualReader reflect.Type
   262  				lr, ok := mw.CalledReader.(*io.LimitedReader)
   263  				if ok && tc.limitedReader {
   264  					actualReader = reflect.TypeOf(lr.R)
   265  				} else {
   266  					actualReader = reflect.TypeOf(mw.CalledReader)
   267  					// We have to handle this special case for genericWriteTo in os,
   268  					// this struct is introduced to support a zero-copy optimization,
   269  					// check out https://go.dev/issue/58808 for details.
   270  					if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" {
   271  						actualReader = actualReader.Field(1).Type
   272  					}
   273  				}
   274  
   275  				if tc.expectedReader != actualReader {
   276  					t.Fatalf("got reader %s want %s", actualReader, tc.expectedReader)
   277  				}
   278  			}
   279  
   280  			if tc.expectedWrite && !mw.WriteCalled {
   281  				t.Fatal("did not invoke Write")
   282  			}
   283  		})
   284  	}
   285  }
   286  
   287  func TestParseTransferEncoding(t *testing.T) {
   288  	tests := []struct {
   289  		hdr     Header
   290  		wantErr error
   291  	}{
   292  		{
   293  			hdr:     Header{"Transfer-Encoding": {"fugazi"}},
   294  			wantErr: &unsupportedTEError{`unsupported transfer encoding: "fugazi"`},
   295  		},
   296  		{
   297  			hdr:     Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}},
   298  			wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked, chunked" "identity" "chunked"]`},
   299  		},
   300  		{
   301  			hdr:     Header{"Transfer-Encoding": {""}},
   302  			wantErr: &unsupportedTEError{`unsupported transfer encoding: ""`},
   303  		},
   304  		{
   305  			hdr:     Header{"Transfer-Encoding": {"chunked, identity"}},
   306  			wantErr: &unsupportedTEError{`unsupported transfer encoding: "chunked, identity"`},
   307  		},
   308  		{
   309  			hdr:     Header{"Transfer-Encoding": {"chunked", "identity"}},
   310  			wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked" "identity"]`},
   311  		},
   312  		{
   313  			hdr:     Header{"Transfer-Encoding": {"\x0bchunked"}},
   314  			wantErr: &unsupportedTEError{`unsupported transfer encoding: "\vchunked"`},
   315  		},
   316  		{
   317  			hdr:     Header{"Transfer-Encoding": {"chunked"}},
   318  			wantErr: nil,
   319  		},
   320  	}
   321  
   322  	for i, tt := range tests {
   323  		tr := &transferReader{
   324  			Header:     tt.hdr,
   325  			ProtoMajor: 1,
   326  			ProtoMinor: 1,
   327  		}
   328  		gotErr := tr.parseTransferEncoding()
   329  		if !reflect.DeepEqual(gotErr, tt.wantErr) {
   330  			t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr)
   331  		}
   332  	}
   333  }
   334  
   335  // issue 39017 - disallow Content-Length values such as "+3"
   336  func TestParseContentLength(t *testing.T) {
   337  	tests := []struct {
   338  		cl      string
   339  		wantErr error
   340  	}{
   341  		{
   342  			cl:      "",
   343  			wantErr: badStringError("invalid empty Content-Length", ""),
   344  		},
   345  		{
   346  			cl:      "3",
   347  			wantErr: nil,
   348  		},
   349  		{
   350  			cl:      "+3",
   351  			wantErr: badStringError("bad Content-Length", "+3"),
   352  		},
   353  		{
   354  			cl:      "-3",
   355  			wantErr: badStringError("bad Content-Length", "-3"),
   356  		},
   357  		{
   358  			// max int64, for safe conversion before returning
   359  			cl:      "9223372036854775807",
   360  			wantErr: nil,
   361  		},
   362  		{
   363  			cl:      "9223372036854775808",
   364  			wantErr: badStringError("bad Content-Length", "9223372036854775808"),
   365  		},
   366  	}
   367  
   368  	for _, tt := range tests {
   369  		if _, gotErr := parseContentLength([]string{tt.cl}); !reflect.DeepEqual(gotErr, tt.wantErr) {
   370  			t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr)
   371  		}
   372  	}
   373  }
   374  

View as plain text