// Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package net import ( "bytes" "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "io" "os" "runtime" "sync" "testing" "time" ) const ( newton = "../testdata/Isaac.Newton-Opticks.txt" newtonLen = 567198 newtonSHA256 = "d4a9ac22462b35e7821a4f2706c211093da678620a8f9997989ee7cf8d507bbd" ) func TestSendfile(t *testing.T) { ln := newLocalListener(t, "tcp") defer ln.Close() errc := make(chan error, 1) go func(ln Listener) { // Wait for a connection. conn, err := ln.Accept() if err != nil { errc <- err close(errc) return } go func() { defer close(errc) defer conn.Close() f, err := os.Open(newton) if err != nil { errc <- err return } defer f.Close() // Return file data using io.Copy, which should use // sendFile if available. sbytes, err := io.Copy(conn, f) if err != nil { errc <- err return } if sbytes != newtonLen { errc <- fmt.Errorf("sent %d bytes; expected %d", sbytes, newtonLen) return } }() }(ln) // Connect to listener to retrieve file and verify digest matches // expected. c, err := Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } defer c.Close() h := sha256.New() rbytes, err := io.Copy(h, c) if err != nil { t.Error(err) } if rbytes != newtonLen { t.Errorf("received %d bytes; expected %d", rbytes, newtonLen) } if res := hex.EncodeToString(h.Sum(nil)); res != newtonSHA256 { t.Error("retrieved data hash did not match") } for err := range errc { t.Error(err) } } func TestSendfileParts(t *testing.T) { ln := newLocalListener(t, "tcp") defer ln.Close() errc := make(chan error, 1) go func(ln Listener) { // Wait for a connection. conn, err := ln.Accept() if err != nil { errc <- err close(errc) return } go func() { defer close(errc) defer conn.Close() f, err := os.Open(newton) if err != nil { errc <- err return } defer f.Close() for i := 0; i < 3; i++ { // Return file data using io.CopyN, which should use // sendFile if available. _, err = io.CopyN(conn, f, 3) if err != nil { errc <- err return } } }() }(ln) c, err := Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } defer c.Close() buf := new(bytes.Buffer) buf.ReadFrom(c) if want, have := "Produced ", buf.String(); have != want { t.Errorf("unexpected server reply %q, want %q", have, want) } for err := range errc { t.Error(err) } } func TestSendfileSeeked(t *testing.T) { ln := newLocalListener(t, "tcp") defer ln.Close() const seekTo = 65 << 10 const sendSize = 10 << 10 errc := make(chan error, 1) go func(ln Listener) { // Wait for a connection. conn, err := ln.Accept() if err != nil { errc <- err close(errc) return } go func() { defer close(errc) defer conn.Close() f, err := os.Open(newton) if err != nil { errc <- err return } defer f.Close() if _, err := f.Seek(seekTo, io.SeekStart); err != nil { errc <- err return } _, err = io.CopyN(conn, f, sendSize) if err != nil { errc <- err return } }() }(ln) c, err := Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } defer c.Close() buf := new(bytes.Buffer) buf.ReadFrom(c) if buf.Len() != sendSize { t.Errorf("Got %d bytes; want %d", buf.Len(), sendSize) } for err := range errc { t.Error(err) } } // Test that sendfile doesn't put a pipe into blocking mode. func TestSendfilePipe(t *testing.T) { switch runtime.GOOS { case "plan9", "windows", "js", "wasip1": // These systems don't support deadlines on pipes. t.Skipf("skipping on %s", runtime.GOOS) } t.Parallel() ln := newLocalListener(t, "tcp") defer ln.Close() r, w, err := os.Pipe() if err != nil { t.Fatal(err) } defer w.Close() defer r.Close() copied := make(chan bool) var wg sync.WaitGroup wg.Add(1) go func() { // Accept a connection and copy 1 byte from the read end of // the pipe to the connection. This will call into sendfile. defer wg.Done() conn, err := ln.Accept() if err != nil { t.Error(err) return } defer conn.Close() _, err = io.CopyN(conn, r, 1) if err != nil { t.Error(err) return } // Signal the main goroutine that we've copied the byte. close(copied) }() wg.Add(1) go func() { // Write 1 byte to the write end of the pipe. defer wg.Done() _, err := w.Write([]byte{'a'}) if err != nil { t.Error(err) } }() wg.Add(1) go func() { // Connect to the server started two goroutines up and // discard any data that it writes. defer wg.Done() conn, err := Dial("tcp", ln.Addr().String()) if err != nil { t.Error(err) return } defer conn.Close() io.Copy(io.Discard, conn) }() // Wait for the byte to be copied, meaning that sendfile has // been called on the pipe. <-copied // Set a very short deadline on the read end of the pipe. if err := r.SetDeadline(time.Now().Add(time.Microsecond)); err != nil { t.Fatal(err) } wg.Add(1) go func() { // Wait for much longer than the deadline and write a byte // to the pipe. defer wg.Done() time.Sleep(50 * time.Millisecond) w.Write([]byte{'b'}) }() // If this read does not time out, the pipe was incorrectly // put into blocking mode. _, err = r.Read(make([]byte, 1)) if err == nil { t.Error("Read did not time out") } else if !os.IsTimeout(err) { t.Errorf("got error %v, expected a time out", err) } wg.Wait() } // Issue 43822: tests that returns EOF when conn write timeout. func TestSendfileOnWriteTimeoutExceeded(t *testing.T) { ln := newLocalListener(t, "tcp") defer ln.Close() errc := make(chan error, 1) go func(ln Listener) (retErr error) { defer func() { errc <- retErr close(errc) }() conn, err := ln.Accept() if err != nil { return err } defer conn.Close() // Set the write deadline in the past(1h ago). It makes // sure that it is always write timeout. if err := conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)); err != nil { return err } f, err := os.Open(newton) if err != nil { return err } defer f.Close() _, err = io.Copy(conn, f) if errors.Is(err, os.ErrDeadlineExceeded) { return nil } if err == nil { err = fmt.Errorf("expected ErrDeadlineExceeded, but got nil") } return err }(ln) conn, err := Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } defer conn.Close() n, err := io.Copy(io.Discard, conn) if err != nil { t.Fatalf("expected nil error, but got %v", err) } if n != 0 { t.Fatalf("expected receive zero, but got %d byte(s)", n) } if err := <-errc; err != nil { t.Fatal(err) } } func BenchmarkSendfileZeroBytes(b *testing.B) { var ( wg sync.WaitGroup ctx, cancel = context.WithCancel(context.Background()) ) defer wg.Wait() ln := newLocalListener(b, "tcp") defer ln.Close() tempFile, err := os.CreateTemp(b.TempDir(), "test.txt") if err != nil { b.Fatalf("failed to create temp file: %v", err) } defer tempFile.Close() fileName := tempFile.Name() dataSize := b.N wg.Add(1) go func(f *os.File) { defer wg.Done() for i := 0; i < dataSize; i++ { if _, err := f.Write([]byte{1}); err != nil { b.Errorf("failed to write: %v", err) return } if i%1000 == 0 { f.Sync() } } }(tempFile) b.ResetTimer() b.ReportAllocs() wg.Add(1) go func(ln Listener, fileName string) { defer wg.Done() conn, err := ln.Accept() if err != nil { b.Errorf("failed to accept: %v", err) return } defer conn.Close() f, err := os.OpenFile(fileName, os.O_RDONLY, 0660) if err != nil { b.Errorf("failed to open file: %v", err) return } defer f.Close() for { if ctx.Err() != nil { return } if _, err := io.Copy(conn, f); err != nil { b.Errorf("failed to copy: %v", err) return } } }(ln, fileName) conn, err := Dial("tcp", ln.Addr().String()) if err != nil { b.Fatalf("failed to dial: %v", err) } defer conn.Close() n, err := io.CopyN(io.Discard, conn, int64(dataSize)) if err != nil { b.Fatalf("failed to copy: %v", err) } if n != int64(dataSize) { b.Fatalf("expected %d copied bytes, but got %d", dataSize, n) } cancel() }