Source file
src/io/multi_test.go
1
2
3
4
5 package io_test
6
7 import (
8 "bytes"
9 "crypto/sha1"
10 "errors"
11 "fmt"
12 . "io"
13 "runtime"
14 "strings"
15 "testing"
16 "time"
17 )
18
// TestMultiReader checks that MultiReader concatenates "foo ", "" and "bar",
// honors destination buffers shorter than the remaining data, never lets a
// single Read span two underlying readers, skips the empty middle reader, and
// reports EOF once everything has been consumed.
func TestMultiReader(t *testing.T) {
	var (
		mr     Reader
		buf    []byte
		readNo int // running index used in failure messages; never reset
	)

	// fresh installs a new MultiReader over the three fixed inputs plus a
	// 20-byte scratch buffer, then runs body against that state.
	fresh := func(body func()) {
		mr = MultiReader(
			strings.NewReader("foo "),
			strings.NewReader(""),
			strings.NewReader("bar"),
		)
		buf = make([]byte, 20)
		body()
	}

	// check issues one Read of at most size bytes and compares the count,
	// the bytes and the error with the expectation. The consumed prefix of
	// buf is then dropped so the next Read writes past it.
	check := func(size int, want string, wantErr error) {
		readNo++
		n, err := mr.Read(buf[:size])
		if n != len(want) {
			t.Errorf("#%d, expected %d bytes; got %d", readNo, len(want), n)
		}
		if got := string(buf[:n]); got != want {
			t.Errorf("#%d, expected %q; got %q", readNo, want, got)
		}
		if err != wantErr {
			t.Errorf("#%d, expected error %v; got %v", readNo, wantErr, err)
		}
		buf = buf[n:]
	}

	type read struct {
		size    int
		want    string
		wantErr error
	}
	scenarios := [][]read{
		{{2, "fo", nil}, {5, "o ", nil}, {5, "bar", nil}, {5, "", EOF}},
		{{4, "foo ", nil}, {1, "b", nil}, {3, "ar", nil}, {1, "", EOF}},
		{{5, "foo ", nil}},
	}
	for _, reads := range scenarios {
		fresh(func() {
			for _, r := range reads {
				check(r.size, r.want, r.wantErr)
			}
		})
	}
}
65
66 func TestMultiWriter(t *testing.T) {
67 sink := new(bytes.Buffer)
68
69 testMultiWriter(t, struct {
70 Writer
71 fmt.Stringer
72 }{sink, sink})
73 }
74
75 func TestMultiWriter_String(t *testing.T) {
76 testMultiWriter(t, new(bytes.Buffer))
77 }
78
79
80
// TestMultiWriter_WriteStringSingleAlloc verifies that WriteString on a
// MultiWriter whose targets all lack a WriteString method costs exactly one
// allocation per call (presumably the shared string-to-[]byte conversion),
// rather than one allocation per target.
func TestMultiWriter_WriteStringSingleAlloc(t *testing.T) {
	// plainWriter embeds only Writer, hiding bytes.Buffer's WriteString.
	type plainWriter struct {
		Writer
	}
	var first, second bytes.Buffer
	mw := MultiWriter(plainWriter{&first}, plainWriter{&second})

	perRun := testing.AllocsPerRun(1000, func() {
		WriteString(mw, "foo")
	})
	if allocs := int(perRun); allocs != 1 {
		t.Errorf("num allocations = %d; want 1", allocs)
	}
}
94
// writeStringChecker is a Writer that additionally implements WriteString.
// Both methods discard their input and report it as fully written; the
// called field records whether WriteString has ever been invoked.
type writeStringChecker struct{ called bool }

// Write reports len(p) bytes written and leaves called untouched.
func (w *writeStringChecker) Write(p []byte) (int, error) {
	return len(p), nil
}

// WriteString sets called and reports len(s) bytes written.
func (w *writeStringChecker) WriteString(s string) (int, error) {
	w.called = true
	return len(s), nil
}
105
106 func TestMultiWriter_StringCheckCall(t *testing.T) {
107 var c writeStringChecker
108 mw := MultiWriter(&c)
109 WriteString(mw, "foo")
110 if !c.called {
111 t.Error("did not see WriteString call to writeStringChecker")
112 }
113 }
114
// testMultiWriter copies a fixed string through MultiWriter(hash, sink) and
// checks four things:
//   - every byte was reported written;
//   - Copy returned no error;
//   - the SHA-1 hash target saw the whole input;
//   - sink recorded the input verbatim.
//
// sink must offer String so its contents can be compared afterwards.
func testMultiWriter(t *testing.T, sink interface {
	Writer
	fmt.Stringer
}) {
	// Named h rather than sha1 so the crypto/sha1 package is not shadowed.
	h := sha1.New()
	mw := MultiWriter(h, sink)

	sourceString := "My input text."
	source := strings.NewReader(sourceString)
	written, err := Copy(mw, source)

	if written != int64(len(sourceString)) {
		t.Errorf("short write of %d, not %d", written, len(sourceString))
	}

	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	const wantSHA1 = "01cb303fa8c30a64123067c5aa6284ba7ec2d31b"
	if got := fmt.Sprintf("%x", h.Sum(nil)); got != wantSHA1 {
		// Report both digests so a failure is diagnosable from the log alone.
		t.Errorf("incorrect sha1 value: got %s, want %s", got, wantSHA1)
	}

	if got := sink.String(); got != sourceString {
		t.Errorf("expected %q; got %q", sourceString, got)
	}
}
143
144
// writerFunc adapts an ordinary function to the Writer interface.
type writerFunc func(p []byte) (int, error)

// Write forwards b to the wrapped function and returns its results unchanged.
func (fn writerFunc) Write(b []byte) (int, error) { return fn(b) }
150
151
152 func TestMultiWriterSingleChainFlatten(t *testing.T) {
153 pc := make([]uintptr, 1000)
154 n := runtime.Callers(0, pc)
155 var myDepth = callDepth(pc[:n])
156 var writeDepth int
157 var w Writer = MultiWriter(writerFunc(func(p []byte) (int, error) {
158 n := runtime.Callers(1, pc)
159 writeDepth += callDepth(pc[:n])
160 return 0, nil
161 }))
162
163 mw := w
164
165 for i := 0; i < 100; i++ {
166 mw = MultiWriter(w)
167 }
168
169 mw = MultiWriter(w, mw, w, mw)
170 mw.Write(nil)
171
172 if writeDepth != 4*(myDepth+2) {
173 t.Errorf("multiWriter did not flatten chained multiWriters: expected writeDepth %d, got %d",
174 4*(myDepth+2), writeDepth)
175 }
176 }
177
178 func TestMultiWriterError(t *testing.T) {
179 f1 := writerFunc(func(p []byte) (int, error) {
180 return len(p) / 2, ErrShortWrite
181 })
182 f2 := writerFunc(func(p []byte) (int, error) {
183 t.Errorf("MultiWriter called f2.Write")
184 return len(p), nil
185 })
186 w := MultiWriter(f1, f2)
187 n, err := w.Write(make([]byte, 100))
188 if n != 50 || err != ErrShortWrite {
189 t.Errorf("Write = %d, %v, want 50, ErrShortWrite", n, err)
190 }
191 }
192
193
// TestMultiReaderCopy verifies that MultiReader is unaffected when the
// caller's reader slice is modified after construction.
func TestMultiReaderCopy(t *testing.T) {
	readers := []Reader{strings.NewReader("hello world")}
	r := MultiReader(readers...)
	readers[0] = nil // mutate the caller's slice after construction

	data, err := ReadAll(r)
	if err == nil && string(data) == "hello world" {
		return
	}
	t.Errorf("ReadAll() = %q, %v, want %q, nil", data, err, "hello world")
}
203
204
// TestMultiWriterCopy verifies that MultiWriter is unaffected when the
// caller's writer slice is modified after construction.
func TestMultiWriterCopy(t *testing.T) {
	var buf bytes.Buffer
	writers := []Writer{&buf}
	w := MultiWriter(writers...)
	writers[0] = nil // mutate the caller's slice after construction

	if n, err := w.Write([]byte("hello world")); err != nil || n != 11 {
		t.Errorf("Write(`hello world`) = %d, %v, want 11, nil", n, err)
	}
	if got := buf.String(); got != "hello world" {
		t.Errorf("buf.String() = %q, want %q", got, "hello world")
	}
}
218
219
// readerFunc adapts an ordinary function to the Reader interface.
type readerFunc func(p []byte) (int, error)

// Read forwards b to the wrapped function and returns its results unchanged.
func (fn readerFunc) Read(b []byte) (int, error) { return fn(b) }
225
226
// callDepth returns the number of stack frames produced by
// runtime.CallersFrames for callers. It is meant for a slice of program
// counters as filled in by runtime.Callers. An empty slice yields 0.
func callDepth(callers []uintptr) int {
	// Without this guard the loop would still consume one (empty) frame and
	// report a depth of 1 for an empty slice.
	if len(callers) == 0 {
		return 0
	}
	depth := 0
	frames := runtime.CallersFrames(callers)
	for more := true; more; depth++ {
		_, more = frames.Next()
	}
	return depth
}
236
237
238 func TestMultiReaderFlatten(t *testing.T) {
239 pc := make([]uintptr, 1000)
240 n := runtime.Callers(0, pc)
241 var myDepth = callDepth(pc[:n])
242 var readDepth int
243 var r Reader = MultiReader(readerFunc(func(p []byte) (int, error) {
244 n := runtime.Callers(1, pc)
245 readDepth = callDepth(pc[:n])
246 return 0, errors.New("irrelevant")
247 }))
248
249
250 for i := 0; i < 100; i++ {
251 r = MultiReader(r)
252 }
253
254 r.Read(nil)
255
256 if readDepth != myDepth+2 {
257 t.Errorf("multiReader did not flatten chained multiReaders: expected readDepth %d, got %d",
258 myDepth+2, readDepth)
259 }
260 }
261
262
263
// byteAndEOFReader is a Reader that, on every call, delivers its single
// underlying byte together with EOF in the same Read. It panics when handed
// an empty buffer, so any zero-length Read issued against it surfaces as a
// test failure.
type byteAndEOFReader byte

// Read stores the byte in p[0] and reports (1, EOF).
func (r byteAndEOFReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		p[0] = byte(r)
		return 1, EOF
	}
	// Not required by the Reader contract; present for test coverage.
	panic("unexpected call")
}
275
276
277 func TestMultiReaderSingleByteWithEOF(t *testing.T) {
278 got, err := ReadAll(LimitReader(MultiReader(byteAndEOFReader('a'), byteAndEOFReader('b')), 10))
279 if err != nil {
280 t.Fatal(err)
281 }
282 const want = "ab"
283 if string(got) != want {
284 t.Errorf("got %q; want %q", got, want)
285 }
286 }
287
288
289
290
291 func TestMultiReaderFinalEOF(t *testing.T) {
292 r := MultiReader(bytes.NewReader(nil), byteAndEOFReader('a'))
293 buf := make([]byte, 2)
294 n, err := r.Read(buf)
295 if n != 1 || err != EOF {
296 t.Errorf("got %v, %v; want 1, EOF", n, err)
297 }
298 }
299
// TestMultiReaderFreesExhaustedReaders verifies that MultiReader drops its
// reference to a reader once that reader is exhausted. The drained reader can
// then be garbage-collected (observed via a finalizer) while the MultiReader
// is still in use.
func TestMultiReaderFreesExhaustedReaders(t *testing.T) {
	var mr Reader
	closed := make(chan struct{})

	// Build the readers inside a closure so that this function's frame keeps
	// no live reference to buf1 afterwards. Presumably this guards against
	// buf1 staying reachable from here (e.g. once MultiReader is inlined),
	// which would keep the finalizer from ever running.
	func() {
		buf1 := bytes.NewReader([]byte("foo"))
		buf2 := bytes.NewReader([]byte("bar"))
		mr = MultiReader(buf1, buf2)
		runtime.SetFinalizer(buf1, func(*bytes.Reader) {
			close(closed)
		})
	}()

	// Reading 4 bytes drains buf1 ("foo") and takes the first byte of buf2.
	buf := make([]byte, 4)
	if n, err := ReadFull(mr, buf); err != nil || string(buf) != "foob" {
		t.Fatalf(`ReadFull = %d (%q), %v; want 4, "foob", nil`, n, buf[:n], err)
	}

	// buf1 is exhausted; if MultiReader released it, its finalizer runs.
	runtime.GC()
	select {
	case <-closed:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for collection of buf1")
	}

	if n, err := ReadFull(mr, buf[:2]); err != nil || string(buf[:2]) != "ar" {
		t.Fatalf(`ReadFull = %d (%q), %v; want 2, "ar", nil`, n, buf[:n], err)
	}
}
331
// TestInterleavedMultiReader verifies that a MultiReader wrapping another
// MultiReader can be read alternately with the inner one. mr2 (which wraps
// mr1) first consumes r1 and part of r2. Reading through mr1 afterwards must
// still return the rest of r2.
func TestInterleavedMultiReader(t *testing.T) {
	mr1 := MultiReader(strings.NewReader("123"), strings.NewReader("45678"))
	mr2 := MultiReader(mr1)
	buf := make([]byte, 4)

	// readInto fills buf from r and returns what was actually read.
	readInto := func(r Reader) (string, error) {
		n, err := ReadFull(r, buf)
		return string(buf[:n]), err
	}

	// Via the outer reader: all of "123" plus the first byte of "45678".
	if got, err := readInto(mr2); got != "1234" || err != nil {
		t.Errorf(`ReadFull(mr2) = (%q, %v), want ("1234", nil)`, got, err)
	}

	// Via the inner reader: the remaining "5678".
	if got, err := readInto(mr1); got != "5678" || err != nil {
		t.Errorf(`ReadFull(mr1) = (%q, %v), want ("5678", nil)`, got, err)
	}
}
355
View as plain text