Source file src/net/splice_test.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build linux
     6  
     7  package net
     8  
     9  import (
    10  	"io"
    11  	"log"
    12  	"os"
    13  	"os/exec"
    14  	"strconv"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  func TestSplice(t *testing.T) {
    21  	t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
    22  	if !testableNetwork("unixgram") {
    23  		t.Skip("skipping unix-to-tcp tests")
    24  	}
    25  	t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
    26  	t.Run("no-unixpacket", testSpliceNoUnixpacket)
    27  	t.Run("no-unixgram", testSpliceNoUnixgram)
    28  }
    29  
    30  func testSplice(t *testing.T, upNet, downNet string) {
    31  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
    32  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
    33  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
    34  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
    35  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
    36  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
    37  	t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
    38  	t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
    39  }
    40  
    41  type spliceTestCase struct {
    42  	upNet, downNet string
    43  
    44  	chunkSize, totalSize int
    45  	limitReadSize        int
    46  }
    47  
    48  func (tc spliceTestCase) test(t *testing.T) {
    49  	clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
    50  	defer serverUp.Close()
    51  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
    52  	if err != nil {
    53  		t.Fatal(err)
    54  	}
    55  	defer cleanup()
    56  	clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
    57  	defer serverDown.Close()
    58  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
    59  	if err != nil {
    60  		t.Fatal(err)
    61  	}
    62  	defer cleanup()
    63  	var (
    64  		r    io.Reader = serverUp
    65  		size           = tc.totalSize
    66  	)
    67  	if tc.limitReadSize > 0 {
    68  		if tc.limitReadSize < size {
    69  			size = tc.limitReadSize
    70  		}
    71  
    72  		r = &io.LimitedReader{
    73  			N: int64(tc.limitReadSize),
    74  			R: serverUp,
    75  		}
    76  		defer serverUp.Close()
    77  	}
    78  	n, err := io.Copy(serverDown, r)
    79  	serverDown.Close()
    80  	if err != nil {
    81  		t.Fatal(err)
    82  	}
    83  	if want := int64(size); want != n {
    84  		t.Errorf("want %d bytes spliced, got %d", want, n)
    85  	}
    86  
    87  	if tc.limitReadSize > 0 {
    88  		wantN := 0
    89  		if tc.limitReadSize > size {
    90  			wantN = tc.limitReadSize - size
    91  		}
    92  
    93  		if n := r.(*io.LimitedReader).N; n != int64(wantN) {
    94  			t.Errorf("r.N = %d, want %d", n, wantN)
    95  		}
    96  	}
    97  }
    98  
    99  func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
   100  	clientUp, serverUp := spliceTestSocketPair(t, upNet)
   101  	defer clientUp.Close()
   102  	clientDown, serverDown := spliceTestSocketPair(t, downNet)
   103  	defer clientDown.Close()
   104  
   105  	serverUp.Close()
   106  
   107  	// We'd like to call net.splice here and check the handled return
   108  	// value, but we disable splice on old Linux kernels.
   109  	//
   110  	// In that case, poll.Splice and net.splice return a non-nil error
   111  	// and handled == false. We'd ideally like to see handled == true
   112  	// because the source reader is at EOF, but if we're running on an old
   113  	// kernel, and splice is disabled, we won't see EOF from net.splice,
   114  	// because we won't touch the reader at all.
   115  	//
   116  	// Trying to untangle the errors from net.splice and match them
   117  	// against the errors created by the poll package would be brittle,
   118  	// so this is a higher level test.
   119  	//
   120  	// The following ReadFrom should return immediately, regardless of
   121  	// whether splice is disabled or not. The other side should then
   122  	// get a goodbye signal. Test for the goodbye signal.
   123  	msg := "bye"
   124  	go func() {
   125  		serverDown.(io.ReaderFrom).ReadFrom(serverUp)
   126  		io.WriteString(serverDown, msg)
   127  		serverDown.Close()
   128  	}()
   129  
   130  	buf := make([]byte, 3)
   131  	_, err := io.ReadFull(clientDown, buf)
   132  	if err != nil {
   133  		t.Errorf("clientDown: %v", err)
   134  	}
   135  	if string(buf) != msg {
   136  		t.Errorf("clientDown got %q, want %q", buf, msg)
   137  	}
   138  }
   139  
   140  func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
   141  	front := newLocalListener(t, upNet)
   142  	defer front.Close()
   143  	back := newLocalListener(t, downNet)
   144  	defer back.Close()
   145  
   146  	var wg sync.WaitGroup
   147  	wg.Add(2)
   148  
   149  	proxy := func() {
   150  		src, err := front.Accept()
   151  		if err != nil {
   152  			return
   153  		}
   154  		dst, err := Dial(downNet, back.Addr().String())
   155  		if err != nil {
   156  			return
   157  		}
   158  		defer dst.Close()
   159  		defer src.Close()
   160  		go func() {
   161  			io.Copy(src, dst)
   162  			wg.Done()
   163  		}()
   164  		go func() {
   165  			io.Copy(dst, src)
   166  			wg.Done()
   167  		}()
   168  	}
   169  
   170  	go proxy()
   171  
   172  	toFront, err := Dial(upNet, front.Addr().String())
   173  	if err != nil {
   174  		t.Fatal(err)
   175  	}
   176  
   177  	io.WriteString(toFront, "foo")
   178  	toFront.Close()
   179  
   180  	fromProxy, err := back.Accept()
   181  	if err != nil {
   182  		t.Fatal(err)
   183  	}
   184  	defer fromProxy.Close()
   185  
   186  	_, err = io.ReadAll(fromProxy)
   187  	if err != nil {
   188  		t.Fatal(err)
   189  	}
   190  
   191  	wg.Wait()
   192  }
   193  
   194  func testSpliceNoUnixpacket(t *testing.T) {
   195  	clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
   196  	defer clientUp.Close()
   197  	defer serverUp.Close()
   198  	clientDown, serverDown := spliceTestSocketPair(t, "tcp")
   199  	defer clientDown.Close()
   200  	defer serverDown.Close()
   201  	// If splice called poll.Splice here, we'd get err == syscall.EINVAL
   202  	// and handled == false.  If poll.Splice gets an EINVAL on the first
   203  	// try, it assumes the kernel it's running on doesn't support splice
   204  	// for unix sockets and returns handled == false. This works for our
   205  	// purposes by somewhat of an accident, but is not entirely correct.
   206  	//
   207  	// What we want is err == nil and handled == false, i.e. we never
   208  	// called poll.Splice, because we know the unix socket's network.
   209  	_, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
   210  	if err != nil || handled != false {
   211  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   212  	}
   213  }
   214  
   215  func testSpliceNoUnixgram(t *testing.T) {
   216  	addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
   217  	if err != nil {
   218  		t.Fatal(err)
   219  	}
   220  	defer os.Remove(addr.Name)
   221  	up, err := ListenUnixgram("unixgram", addr)
   222  	if err != nil {
   223  		t.Fatal(err)
   224  	}
   225  	defer up.Close()
   226  	clientDown, serverDown := spliceTestSocketPair(t, "tcp")
   227  	defer clientDown.Close()
   228  	defer serverDown.Close()
   229  	// Analogous to testSpliceNoUnixpacket.
   230  	_, err, handled := splice(serverDown.(*TCPConn).fd, up)
   231  	if err != nil || handled != false {
   232  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   233  	}
   234  }
   235  
   236  func BenchmarkSplice(b *testing.B) {
   237  	testHookUninstaller.Do(uninstallTestHooks)
   238  
   239  	b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
   240  	b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
   241  }
   242  
   243  func benchSplice(b *testing.B, upNet, downNet string) {
   244  	for i := 0; i <= 10; i++ {
   245  		chunkSize := 1 << uint(i+10)
   246  		tc := spliceTestCase{
   247  			upNet:     upNet,
   248  			downNet:   downNet,
   249  			chunkSize: chunkSize,
   250  		}
   251  
   252  		b.Run(strconv.Itoa(chunkSize), tc.bench)
   253  	}
   254  }
   255  
   256  func (tc spliceTestCase) bench(b *testing.B) {
   257  	// To benchmark the genericReadFrom code path, set this to false.
   258  	useSplice := true
   259  
   260  	clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
   261  	defer serverUp.Close()
   262  
   263  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
   264  	if err != nil {
   265  		b.Fatal(err)
   266  	}
   267  	defer cleanup()
   268  
   269  	clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
   270  	defer serverDown.Close()
   271  
   272  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
   273  	if err != nil {
   274  		b.Fatal(err)
   275  	}
   276  	defer cleanup()
   277  
   278  	b.SetBytes(int64(tc.chunkSize))
   279  	b.ResetTimer()
   280  
   281  	if useSplice {
   282  		_, err := io.Copy(serverDown, serverUp)
   283  		if err != nil {
   284  			b.Fatal(err)
   285  		}
   286  	} else {
   287  		type onlyReader struct {
   288  			io.Reader
   289  		}
   290  		_, err := io.Copy(serverDown, onlyReader{serverUp})
   291  		if err != nil {
   292  			b.Fatal(err)
   293  		}
   294  	}
   295  }
   296  
   297  func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
   298  	t.Helper()
   299  	ln := newLocalListener(t, net)
   300  	defer ln.Close()
   301  	var cerr, serr error
   302  	acceptDone := make(chan struct{})
   303  	go func() {
   304  		server, serr = ln.Accept()
   305  		acceptDone <- struct{}{}
   306  	}()
   307  	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
   308  	<-acceptDone
   309  	if cerr != nil {
   310  		if server != nil {
   311  			server.Close()
   312  		}
   313  		t.Fatal(cerr)
   314  	}
   315  	if serr != nil {
   316  		if client != nil {
   317  			client.Close()
   318  		}
   319  		t.Fatal(serr)
   320  	}
   321  	return client, server
   322  }
   323  
   324  func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
   325  	f, err := conn.(interface{ File() (*os.File, error) }).File()
   326  	if err != nil {
   327  		return nil, err
   328  	}
   329  
   330  	cmd := exec.Command(os.Args[0], os.Args[1:]...)
   331  	cmd.Env = []string{
   332  		"GO_NET_TEST_SPLICE=1",
   333  		"GO_NET_TEST_SPLICE_OP=" + op,
   334  		"GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
   335  		"GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
   336  		"TMPDIR=" + os.Getenv("TMPDIR"),
   337  	}
   338  	cmd.ExtraFiles = append(cmd.ExtraFiles, f)
   339  	cmd.Stdout = os.Stdout
   340  	cmd.Stderr = os.Stderr
   341  
   342  	if err := cmd.Start(); err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	donec := make(chan struct{})
   347  	go func() {
   348  		cmd.Wait()
   349  		conn.Close()
   350  		f.Close()
   351  		close(donec)
   352  	}()
   353  
   354  	return func() {
   355  		select {
   356  		case <-donec:
   357  		case <-time.After(5 * time.Second):
   358  			log.Printf("killing splice client after 5 second shutdown timeout")
   359  			cmd.Process.Kill()
   360  			select {
   361  			case <-donec:
   362  			case <-time.After(5 * time.Second):
   363  				log.Printf("splice client didn't die after 10 seconds")
   364  			}
   365  		}
   366  	}, nil
   367  }
   368  
   369  func init() {
   370  	if os.Getenv("GO_NET_TEST_SPLICE") == "" {
   371  		return
   372  	}
   373  	defer os.Exit(0)
   374  
   375  	f := os.NewFile(uintptr(3), "splice-test-conn")
   376  	defer f.Close()
   377  
   378  	conn, err := FileConn(f)
   379  	if err != nil {
   380  		log.Fatal(err)
   381  	}
   382  
   383  	var chunkSize int
   384  	if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
   385  		log.Fatal(err)
   386  	}
   387  	buf := make([]byte, chunkSize)
   388  
   389  	var totalSize int
   390  	if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
   391  		log.Fatal(err)
   392  	}
   393  
   394  	var fn func([]byte) (int, error)
   395  	switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
   396  	case "r":
   397  		fn = conn.Read
   398  	case "w":
   399  		defer conn.Close()
   400  
   401  		fn = conn.Write
   402  	default:
   403  		log.Fatalf("unknown op %q", op)
   404  	}
   405  
   406  	var n int
   407  	for count := 0; count < totalSize; count += n {
   408  		if count+chunkSize > totalSize {
   409  			buf = buf[:totalSize-count]
   410  		}
   411  
   412  		var err error
   413  		if n, err = fn(buf); err != nil {
   414  			return
   415  		}
   416  	}
   417  }
   418  

View as plain text