Source file
src/net/splice_test.go
1
2
3
4
5
6
7 package net
8
9 import (
10 "io"
11 "log"
12 "os"
13 "os/exec"
14 "strconv"
15 "sync"
16 "testing"
17 "time"
18 )
19
20 func TestSplice(t *testing.T) {
21 t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
22 if !testableNetwork("unixgram") {
23 t.Skip("skipping unix-to-tcp tests")
24 }
25 t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
26 t.Run("no-unixpacket", testSpliceNoUnixpacket)
27 t.Run("no-unixgram", testSpliceNoUnixgram)
28 }
29
30 func testSplice(t *testing.T, upNet, downNet string) {
31 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
32 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
33 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
34 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
35 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
36 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
37 t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
38 t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
39 }
40
// spliceTestCase describes one splice scenario. It is built with positional
// composite literals elsewhere, so the field order below is part of its
// interface and must not change.
type spliceTestCase struct {
	// upNet is the network of the source (upstream) socket pair.
	upNet string
	// downNet is the network of the destination (downstream) socket pair.
	downNet string

	// chunkSize is the per-call I/O size handed to the helper processes.
	chunkSize int
	// totalSize is the number of bytes the helper processes transfer.
	totalSize int
	// limitReadSize, when positive, wraps the source in an io.LimitedReader
	// with this N.
	limitReadSize int
}
47
48 func (tc spliceTestCase) test(t *testing.T) {
49 clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
50 defer serverUp.Close()
51 cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
52 if err != nil {
53 t.Fatal(err)
54 }
55 defer cleanup()
56 clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
57 defer serverDown.Close()
58 cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
59 if err != nil {
60 t.Fatal(err)
61 }
62 defer cleanup()
63 var (
64 r io.Reader = serverUp
65 size = tc.totalSize
66 )
67 if tc.limitReadSize > 0 {
68 if tc.limitReadSize < size {
69 size = tc.limitReadSize
70 }
71
72 r = &io.LimitedReader{
73 N: int64(tc.limitReadSize),
74 R: serverUp,
75 }
76 defer serverUp.Close()
77 }
78 n, err := io.Copy(serverDown, r)
79 serverDown.Close()
80 if err != nil {
81 t.Fatal(err)
82 }
83 if want := int64(size); want != n {
84 t.Errorf("want %d bytes spliced, got %d", want, n)
85 }
86
87 if tc.limitReadSize > 0 {
88 wantN := 0
89 if tc.limitReadSize > size {
90 wantN = tc.limitReadSize - size
91 }
92
93 if n := r.(*io.LimitedReader).N; n != int64(wantN) {
94 t.Errorf("r.N = %d, want %d", n, wantN)
95 }
96 }
97 }
98
99 func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
100 clientUp, serverUp := spliceTestSocketPair(t, upNet)
101 defer clientUp.Close()
102 clientDown, serverDown := spliceTestSocketPair(t, downNet)
103 defer clientDown.Close()
104
105 serverUp.Close()
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123 msg := "bye"
124 go func() {
125 serverDown.(io.ReaderFrom).ReadFrom(serverUp)
126 io.WriteString(serverDown, msg)
127 serverDown.Close()
128 }()
129
130 buf := make([]byte, 3)
131 _, err := io.ReadFull(clientDown, buf)
132 if err != nil {
133 t.Errorf("clientDown: %v", err)
134 }
135 if string(buf) != msg {
136 t.Errorf("clientDown got %q, want %q", buf, msg)
137 }
138 }
139
140 func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
141 front := newLocalListener(t, upNet)
142 defer front.Close()
143 back := newLocalListener(t, downNet)
144 defer back.Close()
145
146 var wg sync.WaitGroup
147 wg.Add(2)
148
149 proxy := func() {
150 src, err := front.Accept()
151 if err != nil {
152 return
153 }
154 dst, err := Dial(downNet, back.Addr().String())
155 if err != nil {
156 return
157 }
158 defer dst.Close()
159 defer src.Close()
160 go func() {
161 io.Copy(src, dst)
162 wg.Done()
163 }()
164 go func() {
165 io.Copy(dst, src)
166 wg.Done()
167 }()
168 }
169
170 go proxy()
171
172 toFront, err := Dial(upNet, front.Addr().String())
173 if err != nil {
174 t.Fatal(err)
175 }
176
177 io.WriteString(toFront, "foo")
178 toFront.Close()
179
180 fromProxy, err := back.Accept()
181 if err != nil {
182 t.Fatal(err)
183 }
184 defer fromProxy.Close()
185
186 _, err = io.ReadAll(fromProxy)
187 if err != nil {
188 t.Fatal(err)
189 }
190
191 wg.Wait()
192 }
193
194 func testSpliceNoUnixpacket(t *testing.T) {
195 clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
196 defer clientUp.Close()
197 defer serverUp.Close()
198 clientDown, serverDown := spliceTestSocketPair(t, "tcp")
199 defer clientDown.Close()
200 defer serverDown.Close()
201
202
203
204
205
206
207
208
209 _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
210 if err != nil || handled != false {
211 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
212 }
213 }
214
215 func testSpliceNoUnixgram(t *testing.T) {
216 addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
217 if err != nil {
218 t.Fatal(err)
219 }
220 defer os.Remove(addr.Name)
221 up, err := ListenUnixgram("unixgram", addr)
222 if err != nil {
223 t.Fatal(err)
224 }
225 defer up.Close()
226 clientDown, serverDown := spliceTestSocketPair(t, "tcp")
227 defer clientDown.Close()
228 defer serverDown.Close()
229
230 _, err, handled := splice(serverDown.(*TCPConn).fd, up)
231 if err != nil || handled != false {
232 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
233 }
234 }
235
236 func BenchmarkSplice(b *testing.B) {
237 testHookUninstaller.Do(uninstallTestHooks)
238
239 b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
240 b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
241 }
242
243 func benchSplice(b *testing.B, upNet, downNet string) {
244 for i := 0; i <= 10; i++ {
245 chunkSize := 1 << uint(i+10)
246 tc := spliceTestCase{
247 upNet: upNet,
248 downNet: downNet,
249 chunkSize: chunkSize,
250 }
251
252 b.Run(strconv.Itoa(chunkSize), tc.bench)
253 }
254 }
255
256 func (tc spliceTestCase) bench(b *testing.B) {
257
258 useSplice := true
259
260 clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
261 defer serverUp.Close()
262
263 cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
264 if err != nil {
265 b.Fatal(err)
266 }
267 defer cleanup()
268
269 clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
270 defer serverDown.Close()
271
272 cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
273 if err != nil {
274 b.Fatal(err)
275 }
276 defer cleanup()
277
278 b.SetBytes(int64(tc.chunkSize))
279 b.ResetTimer()
280
281 if useSplice {
282 _, err := io.Copy(serverDown, serverUp)
283 if err != nil {
284 b.Fatal(err)
285 }
286 } else {
287 type onlyReader struct {
288 io.Reader
289 }
290 _, err := io.Copy(serverDown, onlyReader{serverUp})
291 if err != nil {
292 b.Fatal(err)
293 }
294 }
295 }
296
297 func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
298 t.Helper()
299 ln := newLocalListener(t, net)
300 defer ln.Close()
301 var cerr, serr error
302 acceptDone := make(chan struct{})
303 go func() {
304 server, serr = ln.Accept()
305 acceptDone <- struct{}{}
306 }()
307 client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
308 <-acceptDone
309 if cerr != nil {
310 if server != nil {
311 server.Close()
312 }
313 t.Fatal(cerr)
314 }
315 if serr != nil {
316 if client != nil {
317 client.Close()
318 }
319 t.Fatal(serr)
320 }
321 return client, server
322 }
323
324 func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
325 f, err := conn.(interface{ File() (*os.File, error) }).File()
326 if err != nil {
327 return nil, err
328 }
329
330 cmd := exec.Command(os.Args[0], os.Args[1:]...)
331 cmd.Env = []string{
332 "GO_NET_TEST_SPLICE=1",
333 "GO_NET_TEST_SPLICE_OP=" + op,
334 "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
335 "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
336 "TMPDIR=" + os.Getenv("TMPDIR"),
337 }
338 cmd.ExtraFiles = append(cmd.ExtraFiles, f)
339 cmd.Stdout = os.Stdout
340 cmd.Stderr = os.Stderr
341
342 if err := cmd.Start(); err != nil {
343 return nil, err
344 }
345
346 donec := make(chan struct{})
347 go func() {
348 cmd.Wait()
349 conn.Close()
350 f.Close()
351 close(donec)
352 }()
353
354 return func() {
355 select {
356 case <-donec:
357 case <-time.After(5 * time.Second):
358 log.Printf("killing splice client after 5 second shutdown timeout")
359 cmd.Process.Kill()
360 select {
361 case <-donec:
362 case <-time.After(5 * time.Second):
363 log.Printf("splice client didn't die after 10 seconds")
364 }
365 }
366 }, nil
367 }
368
369 func init() {
370 if os.Getenv("GO_NET_TEST_SPLICE") == "" {
371 return
372 }
373 defer os.Exit(0)
374
375 f := os.NewFile(uintptr(3), "splice-test-conn")
376 defer f.Close()
377
378 conn, err := FileConn(f)
379 if err != nil {
380 log.Fatal(err)
381 }
382
383 var chunkSize int
384 if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
385 log.Fatal(err)
386 }
387 buf := make([]byte, chunkSize)
388
389 var totalSize int
390 if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
391 log.Fatal(err)
392 }
393
394 var fn func([]byte) (int, error)
395 switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
396 case "r":
397 fn = conn.Read
398 case "w":
399 defer conn.Close()
400
401 fn = conn.Write
402 default:
403 log.Fatalf("unknown op %q", op)
404 }
405
406 var n int
407 for count := 0; count < totalSize; count += n {
408 if count+chunkSize > totalSize {
409 buf = buf[:totalSize-count]
410 }
411
412 var err error
413 if n, err = fn(buf); err != nil {
414 return
415 }
416 }
417 }
418
View as plain text