// Copyright 2018 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // +build linux package net import ( "io" "io/ioutil" "log" "os" "os/exec" "strconv" "sync" "testing" "time" ) func TestSplice(t *testing.T) { t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") }) if !testableNetwork("unixgram") { t.Skip("skipping unix-to-tcp tests") } t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") }) t.Run("no-unixpacket", testSpliceNoUnixpacket) t.Run("no-unixgram", testSpliceNoUnixgram) } func testSplice(t *testing.T, upNet, downNet string) { t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test) t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test) t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test) t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test) t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test) t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test) t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) }) t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) }) } type spliceTestCase struct { upNet, downNet string chunkSize, totalSize int limitReadSize int } func (tc spliceTestCase) test(t *testing.T) { clientUp, serverUp, err := spliceTestSocketPair(tc.upNet) if err != nil { t.Fatal(err) } defer serverUp.Close() cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize) if err != nil { t.Fatal(err) } defer cleanup() clientDown, serverDown, err := spliceTestSocketPair(tc.downNet) if err != nil { t.Fatal(err) } defer serverDown.Close() cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize) if err != nil { t.Fatal(err) } defer cleanup() var ( r io.Reader = serverUp size = tc.totalSize ) if tc.limitReadSize > 0 { if tc.limitReadSize < size { size = tc.limitReadSize } r = &io.LimitedReader{ N: int64(tc.limitReadSize), R: serverUp, } defer serverUp.Close() } n, err := io.Copy(serverDown, r) serverDown.Close() if err != nil { t.Fatal(err) } if want := int64(size); want != n { t.Errorf("want %d bytes spliced, got %d", want, n) } if tc.limitReadSize > 0 { wantN := 0 if tc.limitReadSize > size { wantN = tc.limitReadSize - size } if n := r.(*io.LimitedReader).N; n != int64(wantN) { t.Errorf("r.N = %d, want %d", n, wantN) } } } func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { clientUp, serverUp, err := spliceTestSocketPair(upNet) if err != nil { t.Fatal(err) } defer clientUp.Close() clientDown, serverDown, err := spliceTestSocketPair(downNet) if err != nil { t.Fatal(err) } defer clientDown.Close() serverUp.Close() // We'd like to call net.splice here and check the handled return // value, but we disable splice on old Linux kernels. // // In that case, poll.Splice and net.splice return a non-nil error // and handled == false. We'd ideally like to see handled == true // because the source reader is at EOF, but if we're running on an old // kernel, and splice is disabled, we won't see EOF from net.splice, // because we won't touch the reader at all. // // Trying to untangle the errors from net.splice and match them // against the errors created by the poll package would be brittle, // so this is a higher level test. // // The following ReadFrom should return immediately, regardless of // whether splice is disabled or not. The other side should then // get a goodbye signal. Test for the goodbye signal. msg := "bye" go func() { serverDown.(io.ReaderFrom).ReadFrom(serverUp) io.WriteString(serverDown, msg) serverDown.Close() }() buf := make([]byte, 3) _, err = io.ReadFull(clientDown, buf) if err != nil { t.Errorf("clientDown: %v", err) } if string(buf) != msg { t.Errorf("clientDown got %q, want %q", buf, msg) } } func testSpliceIssue25985(t *testing.T, upNet, downNet string) { front, err := newLocalListener(upNet) if err != nil { t.Fatal(err) } defer front.Close() back, err := newLocalListener(downNet) if err != nil { t.Fatal(err) } defer back.Close() var wg sync.WaitGroup wg.Add(2) proxy := func() { src, err := front.Accept() if err != nil { return } dst, err := Dial(downNet, back.Addr().String()) if err != nil { return } defer dst.Close() defer src.Close() go func() { io.Copy(src, dst) wg.Done() }() go func() { io.Copy(dst, src) wg.Done() }() } go proxy() toFront, err := Dial(upNet, front.Addr().String()) if err != nil { t.Fatal(err) } io.WriteString(toFront, "foo") toFront.Close() fromProxy, err := back.Accept() if err != nil { t.Fatal(err) } defer fromProxy.Close() _, err = ioutil.ReadAll(fromProxy) if err != nil { t.Fatal(err) } wg.Wait() } func testSpliceNoUnixpacket(t *testing.T) { clientUp, serverUp, err := spliceTestSocketPair("unixpacket") if err != nil { t.Fatal(err) } defer clientUp.Close() defer serverUp.Close() clientDown, serverDown, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientDown.Close() defer serverDown.Close() // If splice called poll.Splice here, we'd get err == syscall.EINVAL // and handled == false. If poll.Splice gets an EINVAL on the first // try, it assumes the kernel it's running on doesn't support splice // for unix sockets and returns handled == false. This works for our // purposes by somewhat of an accident, but is not entirely correct. // // What we want is err == nil and handled == false, i.e. we never // called poll.Splice, because we know the unix socket's network. _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp) if err != nil || handled != false { t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) } } func testSpliceNoUnixgram(t *testing.T) { addr, err := ResolveUnixAddr("unixgram", testUnixAddr()) if err != nil { t.Fatal(err) } defer os.Remove(addr.Name) up, err := ListenUnixgram("unixgram", addr) if err != nil { t.Fatal(err) } defer up.Close() clientDown, serverDown, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientDown.Close() defer serverDown.Close() // Analogous to testSpliceNoUnixpacket. _, err, handled := splice(serverDown.(*TCPConn).fd, up) if err != nil || handled != false { t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) } } func BenchmarkSplice(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") }) b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") }) } func benchSplice(b *testing.B, upNet, downNet string) { for i := 0; i <= 10; i++ { chunkSize := 1 << uint(i+10) tc := spliceTestCase{ upNet: upNet, downNet: downNet, chunkSize: chunkSize, } b.Run(strconv.Itoa(chunkSize), tc.bench) } } func (tc spliceTestCase) bench(b *testing.B) { // To benchmark the genericReadFrom code path, set this to false. useSplice := true clientUp, serverUp, err := spliceTestSocketPair(tc.upNet) if err != nil { b.Fatal(err) } defer serverUp.Close() cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N) if err != nil { b.Fatal(err) } defer cleanup() clientDown, serverDown, err := spliceTestSocketPair(tc.downNet) if err != nil { b.Fatal(err) } defer serverDown.Close() cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N) if err != nil { b.Fatal(err) } defer cleanup() b.SetBytes(int64(tc.chunkSize)) b.ResetTimer() if useSplice { _, err := io.Copy(serverDown, serverUp) if err != nil { b.Fatal(err) } } else { type onlyReader struct { io.Reader } _, err := io.Copy(serverDown, onlyReader{serverUp}) if err != nil { b.Fatal(err) } } } func spliceTestSocketPair(net string) (client, server Conn, err error) { ln, err := newLocalListener(net) if err != nil { return nil, nil, err } defer ln.Close() var cerr, serr error acceptDone := make(chan struct{}) go func() { server, serr = ln.Accept() acceptDone <- struct{}{} }() client, cerr = Dial(ln.Addr().Network(), ln.Addr().String()) <-acceptDone if cerr != nil { if server != nil { server.Close() } return nil, nil, cerr } if serr != nil { if client != nil { client.Close() } return nil, nil, serr } return client, server, nil } func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) { f, err := conn.(interface{ File() (*os.File, error) }).File() if err != nil { return nil, err } cmd := exec.Command(os.Args[0], os.Args[1:]...) cmd.Env = append(os.Environ(), []string{ "GO_NET_TEST_SPLICE=1", "GO_NET_TEST_SPLICE_OP=" + op, "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize), "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize), }...) cmd.ExtraFiles = append(cmd.ExtraFiles, f) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { return nil, err } donec := make(chan struct{}) go func() { cmd.Wait() conn.Close() f.Close() close(donec) }() return func() { select { case <-donec: case <-time.After(5 * time.Second): log.Printf("killing splice client after 5 second shutdown timeout") cmd.Process.Kill() select { case <-donec: case <-time.After(5 * time.Second): log.Printf("splice client didn't die after 10 seconds") } } }, nil } func init() { if os.Getenv("GO_NET_TEST_SPLICE") == "" { return } defer os.Exit(0) f := os.NewFile(uintptr(3), "splice-test-conn") defer f.Close() conn, err := FileConn(f) if err != nil { log.Fatal(err) } var chunkSize int if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil { log.Fatal(err) } buf := make([]byte, chunkSize) var totalSize int if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil { log.Fatal(err) } var fn func([]byte) (int, error) switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op { case "r": fn = conn.Read case "w": defer conn.Close() fn = conn.Write default: log.Fatalf("unknown op %q", op) } var n int for count := 0; count < totalSize; count += n { if count+chunkSize > totalSize { buf = buf[:totalSize-count] } var err error if n, err = fn(buf); err != nil { return } } }