Source file src/os/readfrom_linux_test.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"internal/poll"
    10  	"io"
    11  	"math/rand"
    12  	"os"
    13  	. "os"
    14  	"path/filepath"
    15  	"strconv"
    16  	"syscall"
    17  	"testing"
    18  	"time"
    19  )
    20  
    21  func TestCopyFileRange(t *testing.T) {
    22  	sizes := []int{
    23  		1,
    24  		42,
    25  		1025,
    26  		syscall.Getpagesize() + 1,
    27  		32769,
    28  	}
    29  	t.Run("Basic", func(t *testing.T) {
    30  		for _, size := range sizes {
    31  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    32  				testCopyFileRange(t, int64(size), -1)
    33  			})
    34  		}
    35  	})
    36  	t.Run("Limited", func(t *testing.T) {
    37  		t.Run("OneLess", func(t *testing.T) {
    38  			for _, size := range sizes {
    39  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    40  					testCopyFileRange(t, int64(size), int64(size)-1)
    41  				})
    42  			}
    43  		})
    44  		t.Run("Half", func(t *testing.T) {
    45  			for _, size := range sizes {
    46  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    47  					testCopyFileRange(t, int64(size), int64(size)/2)
    48  				})
    49  			}
    50  		})
    51  		t.Run("More", func(t *testing.T) {
    52  			for _, size := range sizes {
    53  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    54  					testCopyFileRange(t, int64(size), int64(size)+7)
    55  				})
    56  			}
    57  		})
    58  	})
    59  	t.Run("DoesntTryInAppendMode", func(t *testing.T) {
    60  		dst, src, data, hook := newCopyFileRangeTest(t, 42)
    61  
    62  		dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
    63  		if err != nil {
    64  			t.Fatal(err)
    65  		}
    66  		defer dst2.Close()
    67  
    68  		if _, err := io.Copy(dst2, src); err != nil {
    69  			t.Fatal(err)
    70  		}
    71  		if hook.called {
    72  			t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
    73  		}
    74  		mustSeekStart(t, dst2)
    75  		mustContainData(t, dst2, data) // through traditional means
    76  	})
    77  	t.Run("NotRegular", func(t *testing.T) {
    78  		t.Run("BothPipes", func(t *testing.T) {
    79  			hook := hookCopyFileRange(t)
    80  
    81  			pr1, pw1, err := Pipe()
    82  			if err != nil {
    83  				t.Fatal(err)
    84  			}
    85  			defer pr1.Close()
    86  			defer pw1.Close()
    87  
    88  			pr2, pw2, err := Pipe()
    89  			if err != nil {
    90  				t.Fatal(err)
    91  			}
    92  			defer pr2.Close()
    93  			defer pw2.Close()
    94  
    95  			// The pipe is empty, and PIPE_BUF is large enough
    96  			// for this, by (POSIX) definition, so there is no
    97  			// need for an additional goroutine.
    98  			data := []byte("hello")
    99  			if _, err := pw1.Write(data); err != nil {
   100  				t.Fatal(err)
   101  			}
   102  			pw1.Close()
   103  
   104  			n, err := io.Copy(pw2, pr1)
   105  			if err != nil {
   106  				t.Fatal(err)
   107  			}
   108  			if n != int64(len(data)) {
   109  				t.Fatalf("transferred %d, want %d", n, len(data))
   110  			}
   111  			if !hook.called {
   112  				t.Fatalf("should have called poll.CopyFileRange")
   113  			}
   114  			pw2.Close()
   115  			mustContainData(t, pr2, data)
   116  		})
   117  		t.Run("DstPipe", func(t *testing.T) {
   118  			dst, src, data, hook := newCopyFileRangeTest(t, 255)
   119  			dst.Close()
   120  
   121  			pr, pw, err := Pipe()
   122  			if err != nil {
   123  				t.Fatal(err)
   124  			}
   125  			defer pr.Close()
   126  			defer pw.Close()
   127  
   128  			n, err := io.Copy(pw, src)
   129  			if err != nil {
   130  				t.Fatal(err)
   131  			}
   132  			if n != int64(len(data)) {
   133  				t.Fatalf("transferred %d, want %d", n, len(data))
   134  			}
   135  			if !hook.called {
   136  				t.Fatalf("should have called poll.CopyFileRange")
   137  			}
   138  			pw.Close()
   139  			mustContainData(t, pr, data)
   140  		})
   141  		t.Run("SrcPipe", func(t *testing.T) {
   142  			dst, src, data, hook := newCopyFileRangeTest(t, 255)
   143  			src.Close()
   144  
   145  			pr, pw, err := Pipe()
   146  			if err != nil {
   147  				t.Fatal(err)
   148  			}
   149  			defer pr.Close()
   150  			defer pw.Close()
   151  
   152  			// The pipe is empty, and PIPE_BUF is large enough
   153  			// for this, by (POSIX) definition, so there is no
   154  			// need for an additional goroutine.
   155  			if _, err := pw.Write(data); err != nil {
   156  				t.Fatal(err)
   157  			}
   158  			pw.Close()
   159  
   160  			n, err := io.Copy(dst, pr)
   161  			if err != nil {
   162  				t.Fatal(err)
   163  			}
   164  			if n != int64(len(data)) {
   165  				t.Fatalf("transferred %d, want %d", n, len(data))
   166  			}
   167  			if !hook.called {
   168  				t.Fatalf("should have called poll.CopyFileRange")
   169  			}
   170  			mustSeekStart(t, dst)
   171  			mustContainData(t, dst, data)
   172  		})
   173  	})
   174  	t.Run("Nil", func(t *testing.T) {
   175  		var nilFile *File
   176  		anyFile, err := os.CreateTemp("", "")
   177  		if err != nil {
   178  			t.Fatal(err)
   179  		}
   180  		defer Remove(anyFile.Name())
   181  		defer anyFile.Close()
   182  
   183  		if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
   184  			t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
   185  		}
   186  		if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
   187  			t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
   188  		}
   189  		if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
   190  			t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
   191  		}
   192  
   193  		if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
   194  			t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
   195  		}
   196  		if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
   197  			t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
   198  		}
   199  		if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
   200  			t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
   201  		}
   202  	})
   203  }
   204  
   205  func testCopyFileRange(t *testing.T, size int64, limit int64) {
   206  	dst, src, data, hook := newCopyFileRangeTest(t, size)
   207  
   208  	// If we have a limit, wrap the reader.
   209  	var (
   210  		realsrc io.Reader
   211  		lr      *io.LimitedReader
   212  	)
   213  	if limit >= 0 {
   214  		lr = &io.LimitedReader{N: limit, R: src}
   215  		realsrc = lr
   216  		if limit < int64(len(data)) {
   217  			data = data[:limit]
   218  		}
   219  	} else {
   220  		realsrc = src
   221  	}
   222  
   223  	// Now call ReadFrom (through io.Copy), which will hopefully call
   224  	// poll.CopyFileRange.
   225  	n, err := io.Copy(dst, realsrc)
   226  	if err != nil {
   227  		t.Fatal(err)
   228  	}
   229  
   230  	// If we didn't have a limit, we should have called poll.CopyFileRange
   231  	// with the right file descriptor arguments.
   232  	if limit > 0 && !hook.called {
   233  		t.Fatal("never called poll.CopyFileRange")
   234  	}
   235  	if hook.called && hook.dstfd != int(dst.Fd()) {
   236  		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
   237  	}
   238  	if hook.called && hook.srcfd != int(src.Fd()) {
   239  		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
   240  	}
   241  
   242  	// Check that the offsets after the transfer make sense, that the size
   243  	// of the transfer was reported correctly, and that the destination
   244  	// file contains exactly the bytes we expect it to contain.
   245  	dstoff, err := dst.Seek(0, io.SeekCurrent)
   246  	if err != nil {
   247  		t.Fatal(err)
   248  	}
   249  	srcoff, err := src.Seek(0, io.SeekCurrent)
   250  	if err != nil {
   251  		t.Fatal(err)
   252  	}
   253  	if dstoff != srcoff {
   254  		t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
   255  	}
   256  	if dstoff != int64(len(data)) {
   257  		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
   258  	}
   259  	if n != int64(len(data)) {
   260  		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
   261  	}
   262  	mustSeekStart(t, dst)
   263  	mustContainData(t, dst, data)
   264  
   265  	// If we had a limit, check that it was updated.
   266  	if lr != nil {
   267  		if want := limit - n; lr.N != want {
   268  			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
   269  		}
   270  	}
   271  }
   272  
   273  // newCopyFileRangeTest initializes a new test for copy_file_range.
   274  //
   275  // It creates source and destination files, and populates the source file
   276  // with random data of the specified size. It also hooks package os' call
   277  // to poll.CopyFileRange and returns the hook so it can be inspected.
   278  func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
   279  	t.Helper()
   280  
   281  	hook = hookCopyFileRange(t)
   282  	tmp := t.TempDir()
   283  
   284  	src, err := Create(filepath.Join(tmp, "src"))
   285  	if err != nil {
   286  		t.Fatal(err)
   287  	}
   288  	t.Cleanup(func() { src.Close() })
   289  
   290  	dst, err = Create(filepath.Join(tmp, "dst"))
   291  	if err != nil {
   292  		t.Fatal(err)
   293  	}
   294  	t.Cleanup(func() { dst.Close() })
   295  
   296  	// Populate the source file with data, then rewind it, so it can be
   297  	// consumed by copy_file_range(2).
   298  	prng := rand.New(rand.NewSource(time.Now().Unix()))
   299  	data = make([]byte, size)
   300  	prng.Read(data)
   301  	if _, err := src.Write(data); err != nil {
   302  		t.Fatal(err)
   303  	}
   304  	if _, err := src.Seek(0, io.SeekStart); err != nil {
   305  		t.Fatal(err)
   306  	}
   307  
   308  	return dst, src, data, hook
   309  }
   310  
   311  // mustContainData ensures that the specified file contains exactly the
   312  // specified data.
   313  func mustContainData(t *testing.T, f *File, data []byte) {
   314  	t.Helper()
   315  
   316  	got := make([]byte, len(data))
   317  	if _, err := io.ReadFull(f, got); err != nil {
   318  		t.Fatal(err)
   319  	}
   320  	if !bytes.Equal(got, data) {
   321  		t.Fatalf("didn't get the same data back from %s", f.Name())
   322  	}
   323  	if _, err := f.Read(make([]byte, 1)); err != io.EOF {
   324  		t.Fatalf("not at EOF")
   325  	}
   326  }
   327  
   328  func mustSeekStart(t *testing.T, f *File) {
   329  	if _, err := f.Seek(0, io.SeekStart); err != nil {
   330  		t.Fatal(err)
   331  	}
   332  }
   333  
   334  func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
   335  	h := new(copyFileRangeHook)
   336  	h.install()
   337  	t.Cleanup(h.uninstall)
   338  	return h
   339  }
   340  
   341  type copyFileRangeHook struct {
   342  	called bool
   343  	dstfd  int
   344  	srcfd  int
   345  	remain int64
   346  
   347  	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
   348  }
   349  
   350  func (h *copyFileRangeHook) install() {
   351  	h.original = *PollCopyFileRangeP
   352  	*PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
   353  		h.called = true
   354  		h.dstfd = dst.Sysfd
   355  		h.srcfd = src.Sysfd
   356  		h.remain = remain
   357  		return h.original(dst, src, remain)
   358  	}
   359  }
   360  
   361  func (h *copyFileRangeHook) uninstall() {
   362  	*PollCopyFileRangeP = h.original
   363  }
   364  
   365  // On some kernels copy_file_range fails on files in /proc.
   366  func TestProcCopy(t *testing.T) {
   367  	const cmdlineFile = "/proc/self/cmdline"
   368  	cmdline, err := os.ReadFile(cmdlineFile)
   369  	if err != nil {
   370  		t.Skipf("can't read /proc file: %v", err)
   371  	}
   372  	in, err := os.Open(cmdlineFile)
   373  	if err != nil {
   374  		t.Fatal(err)
   375  	}
   376  	defer in.Close()
   377  	outFile := filepath.Join(t.TempDir(), "cmdline")
   378  	out, err := os.Create(outFile)
   379  	if err != nil {
   380  		t.Fatal(err)
   381  	}
   382  	if _, err := io.Copy(out, in); err != nil {
   383  		t.Fatal(err)
   384  	}
   385  	if err := out.Close(); err != nil {
   386  		t.Fatal(err)
   387  	}
   388  	copy, err := os.ReadFile(outFile)
   389  	if err != nil {
   390  		t.Fatal(err)
   391  	}
   392  	if !bytes.Equal(cmdline, copy) {
   393  		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, copy, cmdline)
   394  	}
   395  }
   396  

View as plain text