Source file
src/net/splice_test.go
1
2
3
4
5
6
7 package net
8
9 import (
10 "io"
11 "log"
12 "os"
13 "os/exec"
14 "strconv"
15 "sync"
16 "testing"
17 "time"
18 )
19
20 func TestSplice(t *testing.T) {
21 t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
22 if !testableNetwork("unixgram") {
23 t.Skip("skipping unix-to-tcp tests")
24 }
25 t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
26 t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
27 t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
28 t.Run("no-unixpacket", testSpliceNoUnixpacket)
29 t.Run("no-unixgram", testSpliceNoUnixgram)
30 }
31
32 func testSpliceToFile(t *testing.T, upNet, downNet string) {
33 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
34 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
35 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
36 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
37 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
38 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
39 }
40
41 func testSplice(t *testing.T, upNet, downNet string) {
42 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
43 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
44 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
45 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
46 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
47 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
48 t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
49 t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
50 }
51
// spliceTestCase describes one splice scenario. Several composite literals in
// this file set its fields positionally, so the field order must not change.
type spliceTestCase struct {
	upNet   string // network of the source socket pair
	downNet string // network of the destination socket pair

	chunkSize     int // bytes per Read/Write call made by the helper process
	totalSize     int // total bytes the helper process transfers
	limitReadSize int // if > 0, the source is wrapped in an io.LimitedReader with this N
}
58
59 func (tc spliceTestCase) test(t *testing.T) {
60 clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
61 defer serverUp.Close()
62 cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
63 if err != nil {
64 t.Fatal(err)
65 }
66 defer cleanup()
67 clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
68 defer serverDown.Close()
69 cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
70 if err != nil {
71 t.Fatal(err)
72 }
73 defer cleanup()
74 var (
75 r io.Reader = serverUp
76 size = tc.totalSize
77 )
78 if tc.limitReadSize > 0 {
79 if tc.limitReadSize < size {
80 size = tc.limitReadSize
81 }
82
83 r = &io.LimitedReader{
84 N: int64(tc.limitReadSize),
85 R: serverUp,
86 }
87 defer serverUp.Close()
88 }
89 n, err := io.Copy(serverDown, r)
90 serverDown.Close()
91 if err != nil {
92 t.Fatal(err)
93 }
94 if want := int64(size); want != n {
95 t.Errorf("want %d bytes spliced, got %d", want, n)
96 }
97
98 if tc.limitReadSize > 0 {
99 wantN := 0
100 if tc.limitReadSize > size {
101 wantN = tc.limitReadSize - size
102 }
103
104 if n := r.(*io.LimitedReader).N; n != int64(wantN) {
105 t.Errorf("r.N = %d, want %d", n, wantN)
106 }
107 }
108 }
109
110 func (tc spliceTestCase) testFile(t *testing.T) {
111 f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
112 if err != nil {
113 t.Fatal(err)
114 }
115 defer f.Close()
116
117 client, server := spliceTestSocketPair(t, tc.upNet)
118 defer server.Close()
119
120 cleanup, err := startSpliceClient(client, "w", tc.chunkSize, tc.totalSize)
121 if err != nil {
122 client.Close()
123 t.Fatal("failed to start splice client:", err)
124 }
125 defer cleanup()
126
127 var (
128 r io.Reader = server
129 actualSize = tc.totalSize
130 )
131 if tc.limitReadSize > 0 {
132 if tc.limitReadSize < actualSize {
133 actualSize = tc.limitReadSize
134 }
135
136 r = &io.LimitedReader{
137 N: int64(tc.limitReadSize),
138 R: r,
139 }
140 }
141
142 got, err := io.Copy(f, r)
143 if err != nil {
144 t.Fatalf("failed to ReadFrom with error: %v", err)
145 }
146 if want := int64(actualSize); got != want {
147 t.Errorf("got %d bytes, want %d", got, want)
148 }
149 if tc.limitReadSize > 0 {
150 wantN := 0
151 if tc.limitReadSize > actualSize {
152 wantN = tc.limitReadSize - actualSize
153 }
154
155 if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
156 t.Errorf("r.N = %d, want %d", gotN, wantN)
157 }
158 }
159 }
160
161 func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
162 clientUp, serverUp := spliceTestSocketPair(t, upNet)
163 defer clientUp.Close()
164 clientDown, serverDown := spliceTestSocketPair(t, downNet)
165 defer clientDown.Close()
166
167 serverUp.Close()
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185 msg := "bye"
186 go func() {
187 serverDown.(io.ReaderFrom).ReadFrom(serverUp)
188 io.WriteString(serverDown, msg)
189 serverDown.Close()
190 }()
191
192 buf := make([]byte, 3)
193 _, err := io.ReadFull(clientDown, buf)
194 if err != nil {
195 t.Errorf("clientDown: %v", err)
196 }
197 if string(buf) != msg {
198 t.Errorf("clientDown got %q, want %q", buf, msg)
199 }
200 }
201
202 func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
203 front := newLocalListener(t, upNet)
204 defer front.Close()
205 back := newLocalListener(t, downNet)
206 defer back.Close()
207
208 var wg sync.WaitGroup
209 wg.Add(2)
210
211 proxy := func() {
212 src, err := front.Accept()
213 if err != nil {
214 return
215 }
216 dst, err := Dial(downNet, back.Addr().String())
217 if err != nil {
218 return
219 }
220 defer dst.Close()
221 defer src.Close()
222 go func() {
223 io.Copy(src, dst)
224 wg.Done()
225 }()
226 go func() {
227 io.Copy(dst, src)
228 wg.Done()
229 }()
230 }
231
232 go proxy()
233
234 toFront, err := Dial(upNet, front.Addr().String())
235 if err != nil {
236 t.Fatal(err)
237 }
238
239 io.WriteString(toFront, "foo")
240 toFront.Close()
241
242 fromProxy, err := back.Accept()
243 if err != nil {
244 t.Fatal(err)
245 }
246 defer fromProxy.Close()
247
248 _, err = io.ReadAll(fromProxy)
249 if err != nil {
250 t.Fatal(err)
251 }
252
253 wg.Wait()
254 }
255
256 func testSpliceNoUnixpacket(t *testing.T) {
257 clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
258 defer clientUp.Close()
259 defer serverUp.Close()
260 clientDown, serverDown := spliceTestSocketPair(t, "tcp")
261 defer clientDown.Close()
262 defer serverDown.Close()
263
264
265
266
267
268
269
270
271 _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
272 if err != nil || handled != false {
273 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
274 }
275 }
276
277 func testSpliceNoUnixgram(t *testing.T) {
278 addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
279 if err != nil {
280 t.Fatal(err)
281 }
282 defer os.Remove(addr.Name)
283 up, err := ListenUnixgram("unixgram", addr)
284 if err != nil {
285 t.Fatal(err)
286 }
287 defer up.Close()
288 clientDown, serverDown := spliceTestSocketPair(t, "tcp")
289 defer clientDown.Close()
290 defer serverDown.Close()
291
292 _, err, handled := splice(serverDown.(*TCPConn).fd, up)
293 if err != nil || handled != false {
294 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
295 }
296 }
297
298 func BenchmarkSplice(b *testing.B) {
299 testHookUninstaller.Do(uninstallTestHooks)
300
301 b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
302 b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
303 }
304
305 func benchSplice(b *testing.B, upNet, downNet string) {
306 for i := 0; i <= 10; i++ {
307 chunkSize := 1 << uint(i+10)
308 tc := spliceTestCase{
309 upNet: upNet,
310 downNet: downNet,
311 chunkSize: chunkSize,
312 }
313
314 b.Run(strconv.Itoa(chunkSize), tc.bench)
315 }
316 }
317
318 func (tc spliceTestCase) bench(b *testing.B) {
319
320 useSplice := true
321
322 clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
323 defer serverUp.Close()
324
325 cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
326 if err != nil {
327 b.Fatal(err)
328 }
329 defer cleanup()
330
331 clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
332 defer serverDown.Close()
333
334 cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
335 if err != nil {
336 b.Fatal(err)
337 }
338 defer cleanup()
339
340 b.SetBytes(int64(tc.chunkSize))
341 b.ResetTimer()
342
343 if useSplice {
344 _, err := io.Copy(serverDown, serverUp)
345 if err != nil {
346 b.Fatal(err)
347 }
348 } else {
349 type onlyReader struct {
350 io.Reader
351 }
352 _, err := io.Copy(serverDown, onlyReader{serverUp})
353 if err != nil {
354 b.Fatal(err)
355 }
356 }
357 }
358
359 func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
360 t.Helper()
361 ln := newLocalListener(t, net)
362 defer ln.Close()
363 var cerr, serr error
364 acceptDone := make(chan struct{})
365 go func() {
366 server, serr = ln.Accept()
367 acceptDone <- struct{}{}
368 }()
369 client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
370 <-acceptDone
371 if cerr != nil {
372 if server != nil {
373 server.Close()
374 }
375 t.Fatal(cerr)
376 }
377 if serr != nil {
378 if client != nil {
379 client.Close()
380 }
381 t.Fatal(serr)
382 }
383 return client, server
384 }
385
386 func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
387 f, err := conn.(interface{ File() (*os.File, error) }).File()
388 if err != nil {
389 return nil, err
390 }
391
392 cmd := exec.Command(os.Args[0], os.Args[1:]...)
393 cmd.Env = []string{
394 "GO_NET_TEST_SPLICE=1",
395 "GO_NET_TEST_SPLICE_OP=" + op,
396 "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
397 "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
398 "TMPDIR=" + os.Getenv("TMPDIR"),
399 }
400 cmd.ExtraFiles = append(cmd.ExtraFiles, f)
401 cmd.Stdout = os.Stdout
402 cmd.Stderr = os.Stderr
403
404 if err := cmd.Start(); err != nil {
405 return nil, err
406 }
407
408 donec := make(chan struct{})
409 go func() {
410 cmd.Wait()
411 conn.Close()
412 f.Close()
413 close(donec)
414 }()
415
416 return func() {
417 select {
418 case <-donec:
419 case <-time.After(5 * time.Second):
420 log.Printf("killing splice client after 5 second shutdown timeout")
421 cmd.Process.Kill()
422 select {
423 case <-donec:
424 case <-time.After(5 * time.Second):
425 log.Printf("splice client didn't die after 10 seconds")
426 }
427 }
428 }, nil
429 }
430
431 func init() {
432 if os.Getenv("GO_NET_TEST_SPLICE") == "" {
433 return
434 }
435 defer os.Exit(0)
436
437 f := os.NewFile(uintptr(3), "splice-test-conn")
438 defer f.Close()
439
440 conn, err := FileConn(f)
441 if err != nil {
442 log.Fatal(err)
443 }
444
445 var chunkSize int
446 if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
447 log.Fatal(err)
448 }
449 buf := make([]byte, chunkSize)
450
451 var totalSize int
452 if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
453 log.Fatal(err)
454 }
455
456 var fn func([]byte) (int, error)
457 switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
458 case "r":
459 fn = conn.Read
460 case "w":
461 defer conn.Close()
462
463 fn = conn.Write
464 default:
465 log.Fatalf("unknown op %q", op)
466 }
467
468 var n int
469 for count := 0; count < totalSize; count += n {
470 if count+chunkSize > totalSize {
471 buf = buf[:totalSize-count]
472 }
473
474 var err error
475 if n, err = fn(buf); err != nil {
476 return
477 }
478 }
479 }
480
481 func BenchmarkSpliceFile(b *testing.B) {
482 b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
483 b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
484 }
485
486 func benchmarkSpliceFile(b *testing.B, proto string) {
487 for i := 0; i <= 10; i++ {
488 size := 1 << (i + 10)
489 bench := spliceFileBench{
490 proto: proto,
491 chunkSize: size,
492 }
493 b.Run(strconv.Itoa(size), bench.benchSpliceFile)
494 }
495 }
496
// spliceFileBench holds the parameters of one BenchmarkSpliceFile
// sub-benchmark.
type spliceFileBench struct {
	proto     string // network of the source socket pair, e.g. "tcp" or "unix"
	chunkSize int    // bytes per helper write; also the bytes counted per iteration
}
501
502 func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
503 f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
504 if err != nil {
505 b.Fatal(err)
506 }
507 defer f.Close()
508
509 totalSize := b.N * bench.chunkSize
510
511 client, server := spliceTestSocketPair(b, bench.proto)
512 defer server.Close()
513
514 cleanup, err := startSpliceClient(client, "w", bench.chunkSize, totalSize)
515 if err != nil {
516 client.Close()
517 b.Fatalf("failed to start splice client: %v", err)
518 }
519 defer cleanup()
520
521 b.ReportAllocs()
522 b.SetBytes(int64(bench.chunkSize))
523 b.ResetTimer()
524
525 got, err := io.Copy(f, server)
526 if err != nil {
527 b.Fatalf("failed to ReadFrom with error: %v", err)
528 }
529 if want := int64(totalSize); got != want {
530 b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
531 }
532 }
533
View as plain text