Source file src/net/splice_test.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build linux
     6  
     7  package net
     8  
     9  import (
    10  	"io"
    11  	"log"
    12  	"os"
    13  	"os/exec"
    14  	"strconv"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  func TestSplice(t *testing.T) {
    21  	t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
    22  	if !testableNetwork("unixgram") {
    23  		t.Skip("skipping unix-to-tcp tests")
    24  	}
    25  	t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
    26  	t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
    27  	t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
    28  	t.Run("no-unixpacket", testSpliceNoUnixpacket)
    29  	t.Run("no-unixgram", testSpliceNoUnixgram)
    30  }
    31  
    32  func testSpliceToFile(t *testing.T, upNet, downNet string) {
    33  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
    34  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
    35  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
    36  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
    37  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
    38  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
    39  }
    40  
    41  func testSplice(t *testing.T, upNet, downNet string) {
    42  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
    43  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
    44  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
    45  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
    46  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
    47  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
    48  	t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
    49  	t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
    50  }
    51  
    52  type spliceTestCase struct {
    53  	upNet, downNet string
    54  
    55  	chunkSize, totalSize int
    56  	limitReadSize        int
    57  }
    58  
    59  func (tc spliceTestCase) test(t *testing.T) {
    60  	clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
    61  	defer serverUp.Close()
    62  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	defer cleanup()
    67  	clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
    68  	defer serverDown.Close()
    69  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
    70  	if err != nil {
    71  		t.Fatal(err)
    72  	}
    73  	defer cleanup()
    74  	var (
    75  		r    io.Reader = serverUp
    76  		size           = tc.totalSize
    77  	)
    78  	if tc.limitReadSize > 0 {
    79  		if tc.limitReadSize < size {
    80  			size = tc.limitReadSize
    81  		}
    82  
    83  		r = &io.LimitedReader{
    84  			N: int64(tc.limitReadSize),
    85  			R: serverUp,
    86  		}
    87  		defer serverUp.Close()
    88  	}
    89  	n, err := io.Copy(serverDown, r)
    90  	serverDown.Close()
    91  	if err != nil {
    92  		t.Fatal(err)
    93  	}
    94  	if want := int64(size); want != n {
    95  		t.Errorf("want %d bytes spliced, got %d", want, n)
    96  	}
    97  
    98  	if tc.limitReadSize > 0 {
    99  		wantN := 0
   100  		if tc.limitReadSize > size {
   101  			wantN = tc.limitReadSize - size
   102  		}
   103  
   104  		if n := r.(*io.LimitedReader).N; n != int64(wantN) {
   105  			t.Errorf("r.N = %d, want %d", n, wantN)
   106  		}
   107  	}
   108  }
   109  
   110  func (tc spliceTestCase) testFile(t *testing.T) {
   111  	f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
   112  	if err != nil {
   113  		t.Fatal(err)
   114  	}
   115  	defer f.Close()
   116  
   117  	client, server := spliceTestSocketPair(t, tc.upNet)
   118  	defer server.Close()
   119  
   120  	cleanup, err := startSpliceClient(client, "w", tc.chunkSize, tc.totalSize)
   121  	if err != nil {
   122  		client.Close()
   123  		t.Fatal("failed to start splice client:", err)
   124  	}
   125  	defer cleanup()
   126  
   127  	var (
   128  		r          io.Reader = server
   129  		actualSize           = tc.totalSize
   130  	)
   131  	if tc.limitReadSize > 0 {
   132  		if tc.limitReadSize < actualSize {
   133  			actualSize = tc.limitReadSize
   134  		}
   135  
   136  		r = &io.LimitedReader{
   137  			N: int64(tc.limitReadSize),
   138  			R: r,
   139  		}
   140  	}
   141  
   142  	got, err := io.Copy(f, r)
   143  	if err != nil {
   144  		t.Fatalf("failed to ReadFrom with error: %v", err)
   145  	}
   146  	if want := int64(actualSize); got != want {
   147  		t.Errorf("got %d bytes, want %d", got, want)
   148  	}
   149  	if tc.limitReadSize > 0 {
   150  		wantN := 0
   151  		if tc.limitReadSize > actualSize {
   152  			wantN = tc.limitReadSize - actualSize
   153  		}
   154  
   155  		if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
   156  			t.Errorf("r.N = %d, want %d", gotN, wantN)
   157  		}
   158  	}
   159  }
   160  
// testSpliceReaderAtEOF checks that ReadFrom on a downNet connection returns
// promptly when its upNet source has already been closed, whether or not
// splice is enabled. It also checks that the destination is still usable
// afterwards: once ReadFrom returns, a goroutine writes "bye", and the test
// expects clientDown to receive exactly those bytes.
func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
	clientUp, serverUp := spliceTestSocketPair(t, upNet)
	defer clientUp.Close()
	clientDown, serverDown := spliceTestSocketPair(t, downNet)
	defer clientDown.Close()

	// NOTE(review): this closes serverUp itself rather than its peer, so the
	// ReadFrom below reads from a closed connection, not one whose peer sent
	// EOF. Confirm this is the intended "at EOF" condition.
	serverUp.Close()

	// We'd like to call net.splice here and check the handled return
	// value, but we disable splice on old Linux kernels.
	//
	// In that case, poll.Splice and net.splice return a non-nil error
	// and handled == false. We'd ideally like to see handled == true
	// because the source reader is at EOF, but if we're running on an old
	// kernel, and splice is disabled, we won't see EOF from net.splice,
	// because we won't touch the reader at all.
	//
	// Trying to untangle the errors from net.splice and match them
	// against the errors created by the poll package would be brittle,
	// so this is a higher level test.
	//
	// The following ReadFrom should return immediately, regardless of
	// whether splice is disabled or not. The other side should then
	// get a goodbye signal. Test for the goodbye signal.
	msg := "bye"
	// This goroutine is not joined. It is the only place serverDown is
	// closed, and that happens after the message has been written.
	go func() {
		serverDown.(io.ReaderFrom).ReadFrom(serverUp)
		io.WriteString(serverDown, msg)
		serverDown.Close()
	}()

	// If ReadFrom did not return, no bytes would arrive and this read would
	// block until the test times out.
	buf := make([]byte, 3)
	_, err := io.ReadFull(clientDown, buf)
	if err != nil {
		t.Errorf("clientDown: %v", err)
	}
	if string(buf) != msg {
		t.Errorf("clientDown got %q, want %q", buf, msg)
	}
}
   201  
// testSpliceIssue25985 is named after Go issue 25985. It builds a small
// proxy that accepts one upNet connection on front, dials back over downNet,
// and copies in both directions with io.Copy, so the splice path may be
// taken. The test then:
//   - writes "foo" through the proxy and closes its own side;
//   - reads everything the back listener receives;
//   - waits for both proxy copy goroutines to finish.
func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
	front := newLocalListener(t, upNet)
	defer front.Close()
	back := newLocalListener(t, downNet)
	defer back.Close()

	var wg sync.WaitGroup
	// One Done per copy goroutine started inside proxy.
	// NOTE(review): if front.Accept or the Dial inside proxy fails, proxy
	// returns before starting either copy goroutine, so Done is never called
	// and wg.Wait below would block forever. A failed Dial also leaves
	// back.Accept waiting. Consider reporting the error and unblocking both.
	wg.Add(2)

	proxy := func() {
		src, err := front.Accept()
		if err != nil {
			return
		}
		dst, err := Dial(downNet, back.Addr().String())
		if err != nil {
			return
		}
		// These deferred closes run as soon as proxy returns, i.e. right
		// after the two copy goroutines below are started, so the copies race
		// with the connections being closed. This is presumably deliberate,
		// to exercise splice on connections closed mid-copy (issue 25985).
		// Confirm before changing it.
		defer dst.Close()
		defer src.Close()
		go func() {
			io.Copy(src, dst)
			wg.Done()
		}()
		go func() {
			io.Copy(dst, src)
			wg.Done()
		}()
	}

	go proxy()

	toFront, err := Dial(upNet, front.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	io.WriteString(toFront, "foo")
	toFront.Close()

	fromProxy, err := back.Accept()
	if err != nil {
		t.Fatal(err)
	}
	defer fromProxy.Close()

	// The contents are not checked; only that reading completes without error.
	_, err = io.ReadAll(fromProxy)
	if err != nil {
		t.Fatal(err)
	}

	wg.Wait()
}
   255  
   256  func testSpliceNoUnixpacket(t *testing.T) {
   257  	clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
   258  	defer clientUp.Close()
   259  	defer serverUp.Close()
   260  	clientDown, serverDown := spliceTestSocketPair(t, "tcp")
   261  	defer clientDown.Close()
   262  	defer serverDown.Close()
   263  	// If splice called poll.Splice here, we'd get err == syscall.EINVAL
   264  	// and handled == false.  If poll.Splice gets an EINVAL on the first
   265  	// try, it assumes the kernel it's running on doesn't support splice
   266  	// for unix sockets and returns handled == false. This works for our
   267  	// purposes by somewhat of an accident, but is not entirely correct.
   268  	//
   269  	// What we want is err == nil and handled == false, i.e. we never
   270  	// called poll.Splice, because we know the unix socket's network.
   271  	_, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
   272  	if err != nil || handled != false {
   273  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   274  	}
   275  }
   276  
   277  func testSpliceNoUnixgram(t *testing.T) {
   278  	addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
   279  	if err != nil {
   280  		t.Fatal(err)
   281  	}
   282  	defer os.Remove(addr.Name)
   283  	up, err := ListenUnixgram("unixgram", addr)
   284  	if err != nil {
   285  		t.Fatal(err)
   286  	}
   287  	defer up.Close()
   288  	clientDown, serverDown := spliceTestSocketPair(t, "tcp")
   289  	defer clientDown.Close()
   290  	defer serverDown.Close()
   291  	// Analogous to testSpliceNoUnixpacket.
   292  	_, err, handled := splice(serverDown.(*TCPConn).fd, up)
   293  	if err != nil || handled != false {
   294  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   295  	}
   296  }
   297  
   298  func BenchmarkSplice(b *testing.B) {
   299  	testHookUninstaller.Do(uninstallTestHooks)
   300  
   301  	b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
   302  	b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
   303  }
   304  
   305  func benchSplice(b *testing.B, upNet, downNet string) {
   306  	for i := 0; i <= 10; i++ {
   307  		chunkSize := 1 << uint(i+10)
   308  		tc := spliceTestCase{
   309  			upNet:     upNet,
   310  			downNet:   downNet,
   311  			chunkSize: chunkSize,
   312  		}
   313  
   314  		b.Run(strconv.Itoa(chunkSize), tc.bench)
   315  	}
   316  }
   317  
   318  func (tc spliceTestCase) bench(b *testing.B) {
   319  	// To benchmark the genericReadFrom code path, set this to false.
   320  	useSplice := true
   321  
   322  	clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
   323  	defer serverUp.Close()
   324  
   325  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
   326  	if err != nil {
   327  		b.Fatal(err)
   328  	}
   329  	defer cleanup()
   330  
   331  	clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
   332  	defer serverDown.Close()
   333  
   334  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
   335  	if err != nil {
   336  		b.Fatal(err)
   337  	}
   338  	defer cleanup()
   339  
   340  	b.SetBytes(int64(tc.chunkSize))
   341  	b.ResetTimer()
   342  
   343  	if useSplice {
   344  		_, err := io.Copy(serverDown, serverUp)
   345  		if err != nil {
   346  			b.Fatal(err)
   347  		}
   348  	} else {
   349  		type onlyReader struct {
   350  			io.Reader
   351  		}
   352  		_, err := io.Copy(serverDown, onlyReader{serverUp})
   353  		if err != nil {
   354  			b.Fatal(err)
   355  		}
   356  	}
   357  }
   358  
   359  func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
   360  	t.Helper()
   361  	ln := newLocalListener(t, net)
   362  	defer ln.Close()
   363  	var cerr, serr error
   364  	acceptDone := make(chan struct{})
   365  	go func() {
   366  		server, serr = ln.Accept()
   367  		acceptDone <- struct{}{}
   368  	}()
   369  	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
   370  	<-acceptDone
   371  	if cerr != nil {
   372  		if server != nil {
   373  			server.Close()
   374  		}
   375  		t.Fatal(cerr)
   376  	}
   377  	if serr != nil {
   378  		if client != nil {
   379  			client.Close()
   380  		}
   381  		t.Fatal(serr)
   382  	}
   383  	return client, server
   384  }
   385  
   386  func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
   387  	f, err := conn.(interface{ File() (*os.File, error) }).File()
   388  	if err != nil {
   389  		return nil, err
   390  	}
   391  
   392  	cmd := exec.Command(os.Args[0], os.Args[1:]...)
   393  	cmd.Env = []string{
   394  		"GO_NET_TEST_SPLICE=1",
   395  		"GO_NET_TEST_SPLICE_OP=" + op,
   396  		"GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
   397  		"GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
   398  		"TMPDIR=" + os.Getenv("TMPDIR"),
   399  	}
   400  	cmd.ExtraFiles = append(cmd.ExtraFiles, f)
   401  	cmd.Stdout = os.Stdout
   402  	cmd.Stderr = os.Stderr
   403  
   404  	if err := cmd.Start(); err != nil {
   405  		return nil, err
   406  	}
   407  
   408  	donec := make(chan struct{})
   409  	go func() {
   410  		cmd.Wait()
   411  		conn.Close()
   412  		f.Close()
   413  		close(donec)
   414  	}()
   415  
   416  	return func() {
   417  		select {
   418  		case <-donec:
   419  		case <-time.After(5 * time.Second):
   420  			log.Printf("killing splice client after 5 second shutdown timeout")
   421  			cmd.Process.Kill()
   422  			select {
   423  			case <-donec:
   424  			case <-time.After(5 * time.Second):
   425  				log.Printf("splice client didn't die after 10 seconds")
   426  			}
   427  		}
   428  	}, nil
   429  }
   430  
// init turns the test binary into the helper process started by
// startSpliceClient. It does this only when GO_NET_TEST_SPLICE is set; in a
// normal test run it returns immediately.
//
// The helper wraps file descriptor 3 as a Conn. Then, depending on
// GO_NET_TEST_SPLICE_OP, it writes ("w") or reads ("r") up to
// GO_NET_TEST_SPLICE_TOTAL_SIZE bytes in calls of at most
// GO_NET_TEST_SPLICE_CHUNK_SIZE bytes, stopping early on the first I/O
// error. The process never returns to the test framework:
//   - the deferred os.Exit(0) ends it after the other deferred closes run;
//   - log.Fatal exits with a non-zero status on setup errors.
func init() {
	if os.Getenv("GO_NET_TEST_SPLICE") == "" {
		return
	}
	// Deferred first, so it runs last, after the deferred Close calls below.
	defer os.Exit(0)

	// The parent passes the connection as ExtraFiles[0], which is fd 3 here.
	f := os.NewFile(uintptr(3), "splice-test-conn")
	defer f.Close()

	conn, err := FileConn(f)
	if err != nil {
		log.Fatal(err)
	}

	var chunkSize int
	if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
		log.Fatal(err)
	}
	buf := make([]byte, chunkSize)

	var totalSize int
	if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
		log.Fatal(err)
	}

	var fn func([]byte) (int, error)
	switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
	case "r":
		fn = conn.Read
	case "w":
		defer conn.Close()

		fn = conn.Write
	default:
		log.Fatalf("unknown op %q", op)
	}

	// count advances by the bytes actually transferred, which may be fewer
	// than len(buf). The final call is trimmed so that count never exceeds
	// totalSize.
	var n int
	for count := 0; count < totalSize; count += n {
		if count+chunkSize > totalSize {
			buf = buf[:totalSize-count]
		}

		var err error
		if n, err = fn(buf); err != nil {
			// Stop quietly (e.g. when the parent closes its end early).
			// The deferred os.Exit(0) still reports success.
			return
		}
	}
}
   480  
   481  func BenchmarkSpliceFile(b *testing.B) {
   482  	b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
   483  	b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
   484  }
   485  
   486  func benchmarkSpliceFile(b *testing.B, proto string) {
   487  	for i := 0; i <= 10; i++ {
   488  		size := 1 << (i + 10)
   489  		bench := spliceFileBench{
   490  			proto:     proto,
   491  			chunkSize: size,
   492  		}
   493  		b.Run(strconv.Itoa(size), bench.benchSpliceFile)
   494  	}
   495  }
   496  
   497  type spliceFileBench struct {
   498  	proto     string
   499  	chunkSize int
   500  }
   501  
   502  func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
   503  	f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
   504  	if err != nil {
   505  		b.Fatal(err)
   506  	}
   507  	defer f.Close()
   508  
   509  	totalSize := b.N * bench.chunkSize
   510  
   511  	client, server := spliceTestSocketPair(b, bench.proto)
   512  	defer server.Close()
   513  
   514  	cleanup, err := startSpliceClient(client, "w", bench.chunkSize, totalSize)
   515  	if err != nil {
   516  		client.Close()
   517  		b.Fatalf("failed to start splice client: %v", err)
   518  	}
   519  	defer cleanup()
   520  
   521  	b.ReportAllocs()
   522  	b.SetBytes(int64(bench.chunkSize))
   523  	b.ResetTimer()
   524  
   525  	got, err := io.Copy(f, server)
   526  	if err != nil {
   527  		b.Fatalf("failed to ReadFrom with error: %v", err)
   528  	}
   529  	if want := int64(totalSize); got != want {
   530  		b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
   531  	}
   532  }
   533  

View as plain text