1
2
3
4
5 package base64
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "reflect"
13 "runtime/debug"
14 "strings"
15 "testing"
16 "time"
17 )
18
// testpair is a reference vector: decoded holds the raw bytes and encoded
// holds their base64 form in the standard alphabet with '=' padding.
type testpair struct {
	decoded, encoded string
}

// pairs are the reference vectors shared by the encode/decode tests. Each
// entry in encodingTests supplies a conv function that rewrites the encoded
// side into the form its encoding is expected to produce.
var pairs = []testpair{
	// RFC 3548 examples
	{"\x14\xfb\x9c\x03\xd9\x7e", "FPucA9l+"},
	{"\x14\xfb\x9c\x03\xd9", "FPucA9k="},
	{"\x14\xfb\x9c\x03", "FPucAw=="},

	// RFC 4648 examples
	{"", ""},
	{"f", "Zg=="},
	{"fo", "Zm8="},
	{"foo", "Zm9v"},
	{"foob", "Zm9vYg=="},
	{"fooba", "Zm9vYmE="},
	{"foobar", "Zm9vYmFy"},

	// Wikipedia examples (presumably from the Base64 article); each vector
	// appears once, since repeating one adds no coverage.
	{"sure.", "c3VyZS4="},
	{"sure", "c3VyZQ=="},
	{"sur", "c3Vy"},
	{"su", "c3U="},
	{"leasure.", "bGVhc3VyZS4="},
	{"easure.", "ZWFzdXJlLg=="},
	{"asure.", "YXN1cmUu"},
}
48
49
// stdRef is the identity conversion. It is used for encodings whose output is
// expected to match the standard, '='-padded reference strings unchanged.
func stdRef(ref string) string {
	return ref
}

// urlRef maps a standard-alphabet reference string into the URL-safe
// alphabet by turning every '+' into '-' and every '/' into '_'.
func urlRef(ref string) string {
	return strings.NewReplacer("+", "-", "/", "_").Replace(ref)
}

// rawRef drops any trailing '=' padding from a reference string.
func rawRef(ref string) string {
	end := len(ref)
	for end > 0 && ref[end-1] == '=' {
		end--
	}
	return ref[:end]
}

// rawURLRef produces the unpadded URL-safe form by applying urlRef and then
// rawRef.
func rawURLRef(ref string) string {
	return rawRef(urlRef(ref))
}
70
71
72 var funnyEncoding = NewEncoding(encodeStd).WithPadding(rune('@'))
73
// funnyRef rewrites a '='-padded reference string so that it uses '@' as
// padding, matching funnyEncoding.
func funnyRef(ref string) string {
	return strings.NewReplacer("=", "@").Replace(ref)
}
77
78 type encodingTest struct {
79 enc *Encoding
80 conv func(string) string
81 }
82
83 var encodingTests = []encodingTest{
84 {StdEncoding, stdRef},
85 {URLEncoding, urlRef},
86 {RawStdEncoding, rawRef},
87 {RawURLEncoding, rawURLRef},
88 {funnyEncoding, funnyRef},
89 {StdEncoding.Strict(), stdRef},
90 {URLEncoding.Strict(), urlRef},
91 {RawStdEncoding.Strict(), rawRef},
92 {RawURLEncoding.Strict(), rawURLRef},
93 {funnyEncoding.Strict(), funnyRef},
94 }
95
96 var bigtest = testpair{
97 "Twas brillig, and the slithy toves",
98 "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==",
99 }
100
// testEqual compares the last two values in args with ==. If they differ, it
// reports msg formatted with all of args as a test error.
// It reports whether the two values were equal.
func testEqual(t *testing.T, msg string, args ...any) bool {
	t.Helper()
	got, want := args[len(args)-2], args[len(args)-1]
	if got == want {
		return true
	}
	t.Errorf(msg, args...)
	return false
}
109
110 func TestEncode(t *testing.T) {
111 for _, p := range pairs {
112 for _, tt := range encodingTests {
113 got := tt.enc.EncodeToString([]byte(p.decoded))
114 testEqual(t, "Encode(%q) = %q, want %q", p.decoded,
115 got, tt.conv(p.encoded))
116 }
117 }
118 }
119
120 func TestEncoder(t *testing.T) {
121 for _, p := range pairs {
122 bb := &bytes.Buffer{}
123 encoder := NewEncoder(StdEncoding, bb)
124 encoder.Write([]byte(p.decoded))
125 encoder.Close()
126 testEqual(t, "Encode(%q) = %q, want %q", p.decoded, bb.String(), p.encoded)
127 }
128 }
129
130 func TestEncoderBuffering(t *testing.T) {
131 input := []byte(bigtest.decoded)
132 for bs := 1; bs <= 12; bs++ {
133 bb := &bytes.Buffer{}
134 encoder := NewEncoder(StdEncoding, bb)
135 for pos := 0; pos < len(input); pos += bs {
136 end := pos + bs
137 if end > len(input) {
138 end = len(input)
139 }
140 n, err := encoder.Write(input[pos:end])
141 testEqual(t, "Write(%q) gave error %v, want %v", input[pos:end], err, error(nil))
142 testEqual(t, "Write(%q) gave length %v, want %v", input[pos:end], n, end-pos)
143 }
144 err := encoder.Close()
145 testEqual(t, "Close gave error %v, want %v", err, error(nil))
146 testEqual(t, "Encoding/%d of %q = %q, want %q", bs, bigtest.decoded, bb.String(), bigtest.encoded)
147 }
148 }
149
150 func TestDecode(t *testing.T) {
151 for _, p := range pairs {
152 for _, tt := range encodingTests {
153 encoded := tt.conv(p.encoded)
154 dbuf := make([]byte, tt.enc.DecodedLen(len(encoded)))
155 count, err := tt.enc.Decode(dbuf, []byte(encoded))
156 testEqual(t, "Decode(%q) = error %v, want %v", encoded, err, error(nil))
157 testEqual(t, "Decode(%q) = length %v, want %v", encoded, count, len(p.decoded))
158 testEqual(t, "Decode(%q) = %q, want %q", encoded, string(dbuf[0:count]), p.decoded)
159
160 dbuf, err = tt.enc.DecodeString(encoded)
161 testEqual(t, "DecodeString(%q) = error %v, want %v", encoded, err, error(nil))
162 testEqual(t, "DecodeString(%q) = %q, want %q", encoded, string(dbuf), p.decoded)
163 }
164 }
165 }
166
167 func TestDecoder(t *testing.T) {
168 for _, p := range pairs {
169 decoder := NewDecoder(StdEncoding, strings.NewReader(p.encoded))
170 dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded)))
171 count, err := decoder.Read(dbuf)
172 if err != nil && err != io.EOF {
173 t.Fatal("Read failed", err)
174 }
175 testEqual(t, "Read from %q = length %v, want %v", p.encoded, count, len(p.decoded))
176 testEqual(t, "Decoding of %q = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded)
177 if err != io.EOF {
178 _, err = decoder.Read(dbuf)
179 }
180 testEqual(t, "Read from %q = %v, want %v", p.encoded, err, io.EOF)
181 }
182 }
183
184 func TestDecoderBuffering(t *testing.T) {
185 for bs := 1; bs <= 12; bs++ {
186 decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded))
187 buf := make([]byte, len(bigtest.decoded)+12)
188 var total int
189 var n int
190 var err error
191 for total = 0; total < len(bigtest.decoded) && err == nil; {
192 n, err = decoder.Read(buf[total : total+bs])
193 total += n
194 }
195 if err != nil && err != io.EOF {
196 t.Errorf("Read from %q at pos %d = %d, unexpected error %v", bigtest.encoded, total, n, err)
197 }
198 testEqual(t, "Decoding/%d of %q = %q, want %q", bs, bigtest.encoded, string(buf[0:total]), bigtest.decoded)
199 }
200 }
201
202 func TestDecodeCorrupt(t *testing.T) {
203 testCases := []struct {
204 input string
205 offset int
206 }{
207 {"", -1},
208 {"\n", -1},
209 {"AAA=\n", -1},
210 {"AAAA\n", -1},
211 {"!!!!", 0},
212 {"====", 0},
213 {"x===", 1},
214 {"=AAA", 0},
215 {"A=AA", 1},
216 {"AA=A", 2},
217 {"AA==A", 4},
218 {"AAA=AAAA", 4},
219 {"AAAAA", 4},
220 {"AAAAAA", 4},
221 {"A=", 1},
222 {"A==", 1},
223 {"AA=", 3},
224 {"AA==", -1},
225 {"AAA=", -1},
226 {"AAAA", -1},
227 {"AAAAAA=", 7},
228 {"YWJjZA=====", 8},
229 {"A!\n", 1},
230 {"A=\n", 1},
231 }
232 for _, tc := range testCases {
233 dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input)))
234 _, err := StdEncoding.Decode(dbuf, []byte(tc.input))
235 if tc.offset == -1 {
236 if err != nil {
237 t.Error("Decoder wrongly detected corruption in", tc.input)
238 }
239 continue
240 }
241 switch err := err.(type) {
242 case CorruptInputError:
243 testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset)
244 default:
245 t.Error("Decoder failed to detect corruption in", tc)
246 }
247 }
248 }
249
250 func TestDecodeBounds(t *testing.T) {
251 var buf [32]byte
252 s := StdEncoding.EncodeToString(buf[:])
253 defer func() {
254 if err := recover(); err != nil {
255 t.Fatalf("Decode panicked unexpectedly: %v\n%s", err, debug.Stack())
256 }
257 }()
258 n, err := StdEncoding.Decode(buf[:], []byte(s))
259 if n != len(buf) || err != nil {
260 t.Fatalf("StdEncoding.Decode = %d, %v, want %d, nil", n, err, len(buf))
261 }
262 }
263
264 func TestEncodedLen(t *testing.T) {
265 for _, tt := range []struct {
266 enc *Encoding
267 n int
268 want int
269 }{
270 {RawStdEncoding, 0, 0},
271 {RawStdEncoding, 1, 2},
272 {RawStdEncoding, 2, 3},
273 {RawStdEncoding, 3, 4},
274 {RawStdEncoding, 7, 10},
275 {StdEncoding, 0, 0},
276 {StdEncoding, 1, 4},
277 {StdEncoding, 2, 4},
278 {StdEncoding, 3, 4},
279 {StdEncoding, 4, 8},
280 {StdEncoding, 7, 12},
281 } {
282 if got := tt.enc.EncodedLen(tt.n); got != tt.want {
283 t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want)
284 }
285 }
286 }
287
288 func TestDecodedLen(t *testing.T) {
289 for _, tt := range []struct {
290 enc *Encoding
291 n int
292 want int
293 }{
294 {RawStdEncoding, 0, 0},
295 {RawStdEncoding, 2, 1},
296 {RawStdEncoding, 3, 2},
297 {RawStdEncoding, 4, 3},
298 {RawStdEncoding, 10, 7},
299 {StdEncoding, 0, 0},
300 {StdEncoding, 4, 3},
301 {StdEncoding, 8, 6},
302 } {
303 if got := tt.enc.DecodedLen(tt.n); got != tt.want {
304 t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want)
305 }
306 }
307 }
308
309 func TestBig(t *testing.T) {
310 n := 3*1000 + 1
311 raw := make([]byte, n)
312 const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
313 for i := 0; i < n; i++ {
314 raw[i] = alpha[i%len(alpha)]
315 }
316 encoded := new(bytes.Buffer)
317 w := NewEncoder(StdEncoding, encoded)
318 nn, err := w.Write(raw)
319 if nn != n || err != nil {
320 t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n)
321 }
322 err = w.Close()
323 if err != nil {
324 t.Fatalf("Encoder.Close() = %v want nil", err)
325 }
326 decoded, err := io.ReadAll(NewDecoder(StdEncoding, encoded))
327 if err != nil {
328 t.Fatalf("io.ReadAll(NewDecoder(...)): %v", err)
329 }
330
331 if !bytes.Equal(raw, decoded) {
332 var i int
333 for i = 0; i < len(decoded) && i < len(raw); i++ {
334 if decoded[i] != raw[i] {
335 break
336 }
337 }
338 t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i)
339 }
340 }
341
342 func TestNewLineCharacters(t *testing.T) {
343
344 const expected = "sure"
345 examples := []string{
346 "c3VyZQ==",
347 "c3VyZQ==\r",
348 "c3VyZQ==\n",
349 "c3VyZQ==\r\n",
350 "c3VyZ\r\nQ==",
351 "c3V\ryZ\nQ==",
352 "c3V\nyZ\rQ==",
353 "c3VyZ\nQ==",
354 "c3VyZQ\n==",
355 "c3VyZQ=\n=",
356 "c3VyZQ=\r\n\r\n=",
357 }
358 for _, e := range examples {
359 buf, err := StdEncoding.DecodeString(e)
360 if err != nil {
361 t.Errorf("Decode(%q) failed: %v", e, err)
362 continue
363 }
364 if s := string(buf); s != expected {
365 t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
366 }
367 }
368 }
369
// nextRead scripts one faultInjectReader.Read call. That call returns at most
// n bytes, together with err.
type nextRead struct {
	n   int
	err error
}

// faultInjectReader is an io.Reader that serves source in pieces. Each Read
// receives one nextRead from nextc, which sets that call's size limit and
// error, and blocks until a value is available.
type faultInjectReader struct {
	source string
	nextc  <-chan nextRead
}

// Read copies up to min(len(p), step.n) bytes of the remaining source into p
// and returns the scripted error for this call.
func (r *faultInjectReader) Read(p []byte) (int, error) {
	step := <-r.nextc
	limit := len(p)
	if step.n < limit {
		limit = step.n
	}
	copied := copy(p[:limit], r.source)
	r.source = r.source[copied:]
	return copied, step.err
}
391
392
393 func TestDecoderIssue3577(t *testing.T) {
394 next := make(chan nextRead, 10)
395 wantErr := errors.New("my error")
396 next <- nextRead{5, nil}
397 next <- nextRead{10, wantErr}
398 next <- nextRead{0, wantErr}
399 d := NewDecoder(StdEncoding, &faultInjectReader{
400 source: "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==",
401 nextc: next,
402 })
403 errc := make(chan error, 1)
404 go func() {
405 _, err := io.ReadAll(d)
406 errc <- err
407 }()
408 select {
409 case err := <-errc:
410 if err != wantErr {
411 t.Errorf("got error %v; want %v", err, wantErr)
412 }
413 case <-time.After(5 * time.Second):
414 t.Errorf("timeout; Decoder blocked without returning an error")
415 }
416 }
417
418 func TestDecoderIssue4779(t *testing.T) {
419 encoded := `CP/EAT8AAAEF
420 AQEBAQEBAAAAAAAAAAMAAQIEBQYHCAkKCwEAAQUBAQEBAQEAAAAAAAAAAQACAwQFBgcICQoLEAAB
421 BAEDAgQCBQcGCAUDDDMBAAIRAwQhEjEFQVFhEyJxgTIGFJGhsUIjJBVSwWIzNHKC0UMHJZJT8OHx
422 Y3M1FqKygyZEk1RkRcKjdDYX0lXiZfKzhMPTdePzRieUpIW0lcTU5PSltcXV5fVWZnaGlqa2xtbm
423 9jdHV2d3h5ent8fX5/cRAAICAQIEBAMEBQYHBwYFNQEAAhEDITESBEFRYXEiEwUygZEUobFCI8FS
424 0fAzJGLhcoKSQ1MVY3M08SUGFqKygwcmNcLSRJNUoxdkRVU2dGXi8rOEw9N14/NGlKSFtJXE1OT0
425 pbXF1eX1VmZ2hpamtsbW5vYnN0dXZ3eHl6e3x//aAAwDAQACEQMRAD8A9VSSSSUpJJJJSkkkJ+Tj
426 1kiy1jCJJDnAcCTykpKkuQ6p/jN6FgmxlNduXawwAzaGH+V6jn/R/wCt71zdn+N/qL3kVYFNYB4N
427 ji6PDVjWpKp9TSXnvTf8bFNjg3qOEa2n6VlLpj/rT/pf567DpX1i6L1hs9Py67X8mqdtg/rUWbbf
428 +gkp0kkkklKSSSSUpJJJJT//0PVUkkklKVLq3WMDpGI7KzrNjADtYNXvI/Mqr/Pd/q9W3vaxjnvM
429 NaCXE9gNSvGPrf8AWS3qmba5jjsJhoB0DAf0NDf6sevf+/lf8Hj0JJATfWT6/dV6oXU1uOLQeKKn
430 EQP+Hubtfe/+R7Mf/g7f5xcocp++Z11JMCJPgFBxOg7/AOuqDx8I/ikpkXkmSdU8mJIJA/O8EMAy
431 j+mSARB/17pKVXYWHXjsj7yIex0PadzXMO1zT5KHoNA3HT8ietoGhgjsfA+CSnvvqh/jJtqsrwOv
432 2b6NGNzXfTYexzJ+nU7/ALkf4P8Awv6P9KvTQQ4AgyDqCF85Pho3CTB7eHwXoH+LT65uZbX9X+o2
433 bqbPb06551Y4
434 `
435 encodedShort := strings.ReplaceAll(encoded, "\n", "")
436
437 dec := NewDecoder(StdEncoding, strings.NewReader(encoded))
438 res1, err := io.ReadAll(dec)
439 if err != nil {
440 t.Errorf("ReadAll failed: %v", err)
441 }
442
443 dec = NewDecoder(StdEncoding, strings.NewReader(encodedShort))
444 var res2 []byte
445 res2, err = io.ReadAll(dec)
446 if err != nil {
447 t.Errorf("ReadAll failed: %v", err)
448 }
449
450 if !bytes.Equal(res1, res2) {
451 t.Error("Decoded results not equal")
452 }
453 }
454
455 func TestDecoderIssue7733(t *testing.T) {
456 s, err := StdEncoding.DecodeString("YWJjZA=====")
457 want := CorruptInputError(8)
458 if !reflect.DeepEqual(want, err) {
459 t.Errorf("Error = %v; want CorruptInputError(8)", err)
460 }
461 if string(s) != "abcd" {
462 t.Errorf("DecodeString = %q; want abcd", s)
463 }
464 }
465
466 func TestDecoderIssue15656(t *testing.T) {
467 _, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
468 want := CorruptInputError(22)
469 if !reflect.DeepEqual(want, err) {
470 t.Errorf("Error = %v; want CorruptInputError(22)", err)
471 }
472 _, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==")
473 if err != nil {
474 t.Errorf("Error = %v; want nil", err)
475 }
476 _, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
477 if err != nil {
478 t.Errorf("Error = %v; want nil", err)
479 }
480 }
481
482 func BenchmarkEncodeToString(b *testing.B) {
483 data := make([]byte, 8192)
484 b.SetBytes(int64(len(data)))
485 for i := 0; i < b.N; i++ {
486 StdEncoding.EncodeToString(data)
487 }
488 }
489
490 func BenchmarkDecodeString(b *testing.B) {
491 sizes := []int{2, 4, 8, 64, 8192}
492 benchFunc := func(b *testing.B, benchSize int) {
493 data := StdEncoding.EncodeToString(make([]byte, benchSize))
494 b.SetBytes(int64(len(data)))
495 b.ResetTimer()
496 for i := 0; i < b.N; i++ {
497 StdEncoding.DecodeString(data)
498 }
499 }
500 for _, size := range sizes {
501 b.Run(fmt.Sprintf("%d", size), func(b *testing.B) {
502 benchFunc(b, size)
503 })
504 }
505 }
506
507 func TestDecoderRaw(t *testing.T) {
508 source := "AAAAAA"
509 want := []byte{0, 0, 0, 0}
510
511
512 dec1, err := RawURLEncoding.DecodeString(source)
513 if err != nil || !bytes.Equal(dec1, want) {
514 t.Errorf("RawURLEncoding.DecodeString(%q) = %x, %v, want %x, nil", source, dec1, err, want)
515 }
516
517
518 r := NewDecoder(RawURLEncoding, bytes.NewReader([]byte(source)))
519 dec2, err := io.ReadAll(io.LimitReader(r, 100))
520 if err != nil || !bytes.Equal(dec2, want) {
521 t.Errorf("reading NewDecoder(RawURLEncoding, %q) = %x, %v, want %x, nil", source, dec2, err, want)
522 }
523
524
525 r = NewDecoder(URLEncoding, bytes.NewReader([]byte(source+"==")))
526 dec3, err := io.ReadAll(r)
527 if err != nil || !bytes.Equal(dec3, want) {
528 t.Errorf("reading NewDecoder(URLEncoding, %q) = %x, %v, want %x, nil", source+"==", dec3, err, want)
529 }
530 }
531
View as plain text