1
2
3
4
5
6 package base64
7
8 import (
9 "encoding/binary"
10 "io"
11 "strconv"
12 )
13
14
17
18
19
20
21
22
// Encoding describes one base64 scheme: the 64-symbol alphabet used when
// encoding, the reverse lookup table used when decoding, the padding
// character, and whether decoding runs in strict mode.
type Encoding struct {
	encode    [64]byte  // symbol emitted for each 6-bit value
	decodeMap [256]byte // 6-bit value for each input byte; 0xFF marks bytes outside the alphabet
	padChar   rune      // padding character, or NoPadding when output is unpadded
	strict    bool      // when set, decoding rejects non-zero trailing padding bits
}
29
// Padding settings understood by Encoding.WithPadding.
const (
	// StdPadding is the standard '=' padding character.
	StdPadding rune = '='
	// NoPadding selects unpadded output.
	NoPadding rune = -1
)

// The two built-in 64-symbol alphabets. They differ only in the last two
// symbols: "+/" for the standard alphabet and "-_" for the URL alphabet.
const (
	encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
)
37
38
39
40
41
42
43 func NewEncoding(encoder string) *Encoding {
44 if len(encoder) != 64 {
45 panic("encoding alphabet is not 64-bytes long")
46 }
47 for i := 0; i < len(encoder); i++ {
48 if encoder[i] == '\n' || encoder[i] == '\r' {
49 panic("encoding alphabet contains newline character")
50 }
51 }
52
53 e := new(Encoding)
54 e.padChar = StdPadding
55 copy(e.encode[:], encoder)
56
57 for i := 0; i < len(e.decodeMap); i++ {
58 e.decodeMap[i] = 0xFF
59 }
60 for i := 0; i < len(encoder); i++ {
61 e.decodeMap[encoder[i]] = byte(i)
62 }
63 return e
64 }
65
66
67
68
69
70
71 func (enc Encoding) WithPadding(padding rune) *Encoding {
72 if padding == '\r' || padding == '\n' || padding > 0xff {
73 panic("invalid padding")
74 }
75
76 for i := 0; i < len(enc.encode); i++ {
77 if rune(enc.encode[i]) == padding {
78 panic("padding contained in alphabet")
79 }
80 }
81
82 enc.padChar = padding
83 return &enc
84 }
85
86
87
88
89
90
91
92 func (enc Encoding) Strict() *Encoding {
93 enc.strict = true
94 return &enc
95 }
96
97
98
99 var StdEncoding = NewEncoding(encodeStd)
100
101
102
103 var URLEncoding = NewEncoding(encodeURL)
104
105
106
107
108 var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
109
110
111
112
113 var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
114
115
118
119
120
121
122
123
124
125 func (enc *Encoding) Encode(dst, src []byte) {
126 if len(src) == 0 {
127 return
128 }
129
130
131
132 _ = enc.encode
133
134 di, si := 0, 0
135 n := (len(src) / 3) * 3
136 for si < n {
137
138 val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
139
140 dst[di+0] = enc.encode[val>>18&0x3F]
141 dst[di+1] = enc.encode[val>>12&0x3F]
142 dst[di+2] = enc.encode[val>>6&0x3F]
143 dst[di+3] = enc.encode[val&0x3F]
144
145 si += 3
146 di += 4
147 }
148
149 remain := len(src) - si
150 if remain == 0 {
151 return
152 }
153
154 val := uint(src[si+0]) << 16
155 if remain == 2 {
156 val |= uint(src[si+1]) << 8
157 }
158
159 dst[di+0] = enc.encode[val>>18&0x3F]
160 dst[di+1] = enc.encode[val>>12&0x3F]
161
162 switch remain {
163 case 2:
164 dst[di+2] = enc.encode[val>>6&0x3F]
165 if enc.padChar != NoPadding {
166 dst[di+3] = byte(enc.padChar)
167 }
168 case 1:
169 if enc.padChar != NoPadding {
170 dst[di+2] = byte(enc.padChar)
171 dst[di+3] = byte(enc.padChar)
172 }
173 }
174 }
175
176
177 func (enc *Encoding) EncodeToString(src []byte) string {
178 buf := make([]byte, enc.EncodedLen(len(src)))
179 enc.Encode(buf, src)
180 return string(buf)
181 }
182
// encoder is the streaming encoder state behind NewEncoder. It buffers input
// until a full 3-byte group is available and writes encoded output to w.
type encoder struct {
	err  error      // first error returned by w; once set, Write and Close return it
	enc  *Encoding  // encoding used for the output
	w    io.Writer  // destination for encoded data
	buf  [3]byte    // input bytes not yet forming a complete 3-byte group
	nbuf int        // number of valid bytes in buf
	out  [1024]byte // scratch buffer for encoded output
}
191
192 func (e *encoder) Write(p []byte) (n int, err error) {
193 if e.err != nil {
194 return 0, e.err
195 }
196
197
198 if e.nbuf > 0 {
199 var i int
200 for i = 0; i < len(p) && e.nbuf < 3; i++ {
201 e.buf[e.nbuf] = p[i]
202 e.nbuf++
203 }
204 n += i
205 p = p[i:]
206 if e.nbuf < 3 {
207 return
208 }
209 e.enc.Encode(e.out[:], e.buf[:])
210 if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
211 return n, e.err
212 }
213 e.nbuf = 0
214 }
215
216
217 for len(p) >= 3 {
218 nn := len(e.out) / 4 * 3
219 if nn > len(p) {
220 nn = len(p)
221 nn -= nn % 3
222 }
223 e.enc.Encode(e.out[:], p[:nn])
224 if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
225 return n, e.err
226 }
227 n += nn
228 p = p[nn:]
229 }
230
231
232 copy(e.buf[:], p)
233 e.nbuf = len(p)
234 n += len(p)
235 return
236 }
237
238
239
240 func (e *encoder) Close() error {
241
242 if e.err == nil && e.nbuf > 0 {
243 e.enc.Encode(e.out[:], e.buf[:e.nbuf])
244 _, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
245 e.nbuf = 0
246 }
247 return e.err
248 }
249
250
251
252
253
254
255 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
256 return &encoder{enc: enc, w: w}
257 }
258
259
260
261 func (enc *Encoding) EncodedLen(n int) int {
262 if enc.padChar == NoPadding {
263 return (n*8 + 5) / 6
264 }
265 return (n + 2) / 3 * 4
266 }
267
268
271
// CorruptInputError reports malformed base64 input. Its value is the byte
// offset in the input at which the problem was detected.
type CorruptInputError int64

// Error implements the error interface, reporting the offending offset in
// decimal.
func (e CorruptInputError) Error() string {
	const prefix = "illegal base64 data at input byte "
	return string(strconv.AppendInt([]byte(prefix), int64(e), 10))
}
277
278
279
280
281
282
// decodeQuantum decodes one quantum (up to 4 base64 symbols) from src,
// starting at index si, and stores up to 3 decoded bytes at the start of dst.
// '\r' and '\n' in the input are skipped.
//
// It returns the index just past the input it consumed (nsi), the number of
// bytes it reports as decoded (n), and an error. Malformed input yields a
// CorruptInputError carrying an input offset. When only trailing garbage
// after valid padding is detected, the decoded byte count is still returned
// together with the error.
func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
	// Decode quantum using the base64 alphabet.
	var dbuf [4]byte
	dlen := 4

	// Dereference enc once before the loop (presumably to lift the nil check
	// out of it).
	_ = enc.decodeMap

	for j := 0; j < len(dbuf); j++ {
		if len(src) == si {
			switch {
			case j == 0:
				// Nothing left to decode: clean end of input.
				return si, 0, nil
			case j == 1, enc.padChar != NoPadding:
				// A lone symbol can never be valid, and a padded encoding
				// must not end in the middle of a quantum.
				return si, 0, CorruptInputError(si - j)
			}
			// Unpadded encoding ending with 2 or 3 symbols.
			dlen = j
			break
		}
		in := src[si]
		si++

		out := enc.decodeMap[in]
		if out != 0xff {
			dbuf[j] = out
			continue
		}

		if in == '\n' || in == '\r' {
			// Newlines do not count as symbols; retry this slot.
			j--
			continue
		}

		if rune(in) != enc.padChar {
			return si, 0, CorruptInputError(si - 1)
		}

		// We've reached the end and there's padding.
		switch j {
		case 0, 1:
			// Incorrect padding: a quantum needs at least two symbols.
			return si, 0, CorruptInputError(si - 1)
		case 2:
			// Two padding characters are expected; the first is already
			// consumed. Skip over newlines before looking for the second.
			for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
				si++
			}
			if si == len(src) {
				// Not enough padding.
				return si, 0, CorruptInputError(len(src))
			}
			if rune(src[si]) != enc.padChar {
				// Incorrect padding.
				return si, 0, CorruptInputError(si - 1)
			}

			si++
		}

		// Skip over newlines following the padding.
		for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
			si++
		}
		if si < len(src) {
			// Trailing garbage after the padding.
			err = CorruptInputError(si)
		}
		dlen = j
		break
	}

	// Convert 4x 6-bit source values into 3 bytes.
	val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
	dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
	// Store dlen-1 bytes, highest index first. Each case zeroes the byte it
	// has stored, so the strict checks below only see bytes that were NOT
	// emitted; those must be zero in strict mode.
	switch dlen {
	case 4:
		dst[2] = dbuf[2]
		dbuf[2] = 0
		fallthrough
	case 3:
		dst[1] = dbuf[1]
		if enc.strict && dbuf[2] != 0 {
			return si, 0, CorruptInputError(si - 1)
		}
		dbuf[1] = 0
		fallthrough
	case 2:
		dst[0] = dbuf[0]
		if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
			return si, 0, CorruptInputError(si - 2)
		}
	}

	return si, dlen - 1, err
}
379
380
381 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
382 dbuf := make([]byte, enc.DecodedLen(len(s)))
383 n, err := enc.Decode(dbuf, []byte(s))
384 return dbuf[:n], err
385 }
386
// decoder is the streaming decoder state behind NewDecoder. It reads encoded
// data from r into buf, decodes whole 4-byte quanta, and keeps decoded bytes
// that did not fit into the caller's buffer in out.
type decoder struct {
	err     error               // sticky error returned once leftover output is drained
	readErr error               // error from r.Read, held back until buffered input is processed
	enc     *Encoding           // encoding used to interpret the input
	r       io.Reader           // source of encoded data
	buf     [1024]byte          // encoded input not yet decoded
	nbuf    int                 // number of valid bytes in buf
	out     []byte              // decoded bytes not yet handed to the caller; aliases outbuf
	outbuf  [1024 / 4 * 3]byte  // backing storage for out
}
397
// Read decodes data from the underlying reader into p. It first hands out
// decoded bytes left over from a previous call, then refills the input
// buffer until at least one 4-byte quantum is available, and decodes as many
// whole quanta as it holds. An input that ends part-way through a quantum is
// reported as io.ErrUnexpectedEOF, except for unpadded encodings, where the
// final fragment is decoded as is.
func (d *decoder) Read(p []byte) (n int, err error) {
	// Use leftover decoded output from the last read.
	if len(d.out) > 0 {
		n = copy(p, d.out)
		d.out = d.out[n:]
		return n, nil
	}

	if d.err != nil {
		return 0, d.err
	}

	// NewDecoder wraps the source in a newlineFilteringReader, so input
	// arriving here is expected to be free of '\r' and '\n'.

	// Refill the input buffer until it holds at least one quantum or the
	// reader reports an error. Ask for roughly as much input as p can hold
	// once decoded, but never less than one quantum or more than len(d.buf).
	for d.nbuf < 4 && d.readErr == nil {
		nn := len(p) / 3 * 4
		if nn < 4 {
			nn = 4
		}
		if nn > len(d.buf) {
			nn = len(d.buf)
		}
		nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
		d.nbuf += nn
	}

	if d.nbuf < 4 {
		if d.enc.padChar == NoPadding && d.nbuf > 0 {
			// Decode the final fragment, which has no padding.
			var nw int
			nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
			d.nbuf = 0
			d.out = d.outbuf[:nw]
			n = copy(p, d.out)
			d.out = d.out[n:]
			// Report progress first; any error is returned by a later call.
			if n > 0 || len(p) == 0 && len(d.out) > 0 {
				return n, nil
			}
			if d.err != nil {
				return 0, d.err
			}
		}
		d.err = d.readErr
		if d.err == io.EOF && d.nbuf > 0 {
			// The input ended in the middle of a quantum.
			d.err = io.ErrUnexpectedEOF
		}
		return 0, d.err
	}

	// Decode the buffered whole quanta into p, or into d.outbuf first when p
	// is too small to hold all of the decoded bytes.
	nr := d.nbuf / 4 * 4
	nw := d.nbuf / 4 * 3
	if nw > len(p) {
		nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
		d.out = d.outbuf[:nw]
		n = copy(p, d.out)
		d.out = d.out[n:]
	} else {
		n, d.err = d.enc.Decode(p, d.buf[:nr])
	}
	// Move the undecoded remainder (fewer than 4 bytes) to the front of buf.
	d.nbuf -= nr
	copy(d.buf[:d.nbuf], d.buf[nr:])
	return n, d.err
}
463
464
465
466
467
468
469 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
470 if len(src) == 0 {
471 return 0, nil
472 }
473
474
475
476
477 _ = enc.decodeMap
478
479 si := 0
480 for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
481 src2 := src[si : si+8]
482 if dn, ok := assemble64(
483 enc.decodeMap[src2[0]],
484 enc.decodeMap[src2[1]],
485 enc.decodeMap[src2[2]],
486 enc.decodeMap[src2[3]],
487 enc.decodeMap[src2[4]],
488 enc.decodeMap[src2[5]],
489 enc.decodeMap[src2[6]],
490 enc.decodeMap[src2[7]],
491 ); ok {
492 binary.BigEndian.PutUint64(dst[n:], dn)
493 n += 6
494 si += 8
495 } else {
496 var ninc int
497 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
498 n += ninc
499 if err != nil {
500 return n, err
501 }
502 }
503 }
504
505 for len(src)-si >= 4 && len(dst)-n >= 4 {
506 src2 := src[si : si+4]
507 if dn, ok := assemble32(
508 enc.decodeMap[src2[0]],
509 enc.decodeMap[src2[1]],
510 enc.decodeMap[src2[2]],
511 enc.decodeMap[src2[3]],
512 ); ok {
513 binary.BigEndian.PutUint32(dst[n:], dn)
514 n += 3
515 si += 4
516 } else {
517 var ninc int
518 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
519 n += ninc
520 if err != nil {
521 return n, err
522 }
523 }
524 }
525
526 for si < len(src) {
527 var ninc int
528 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
529 n += ninc
530 if err != nil {
531 return n, err
532 }
533 }
534 return n, err
535 }
536
537
538
539
// assemble32 packs four 6-bit values into the upper 24 bits of a uint32,
// first value highest. It reports ok=false (and dn=0) when the bitwise OR of
// the inputs is 0xff. Callers pass decodeMap entries, which are either a
// 6-bit value or the 0xff "invalid" marker, so the OR is 0xff exactly when
// at least one input is invalid.
func assemble32(n1, n2, n3, n4 byte) (dn uint32, ok bool) {
	if n1|n2|n3|n4 != 0xff {
		dn = uint32(n1)<<26 |
			uint32(n2)<<20 |
			uint32(n3)<<14 |
			uint32(n4)<<8
		ok = true
	}
	return dn, ok
}
552
553
554
555
// assemble64 packs eight 6-bit values into the upper 48 bits of a uint64,
// first value highest. It reports ok=false (and dn=0) when the bitwise OR of
// the inputs is 0xff. Callers pass decodeMap entries, which are either a
// 6-bit value or the 0xff "invalid" marker, so the OR is 0xff exactly when
// at least one input is invalid.
func assemble64(n1, n2, n3, n4, n5, n6, n7, n8 byte) (dn uint64, ok bool) {
	if n1|n2|n3|n4|n5|n6|n7|n8 != 0xff {
		dn = uint64(n1)<<58 |
			uint64(n2)<<52 |
			uint64(n3)<<46 |
			uint64(n4)<<40 |
			uint64(n5)<<34 |
			uint64(n6)<<28 |
			uint64(n7)<<22 |
			uint64(n8)<<16
		ok = true
	}
	return dn, ok
}
572
// newlineFilteringReader wraps an io.Reader and removes every '\r' and '\n'
// byte from the data it passes through.
type newlineFilteringReader struct {
	wrapped io.Reader
}

// Read reads from the wrapped reader into p and compacts the result in place
// so that it contains no '\r' or '\n'. It returns the number of bytes kept.
//
// If a read yields nothing but newlines and no error, Read tries again
// rather than returning (0, nil). An error from the wrapped reader is always
// returned together with the bytes read alongside it, even when every one of
// those bytes was filtered out.
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
	n, err := r.wrapped.Read(p)
	for n > 0 {
		// Slide the kept bytes toward the front of p.
		offset := 0
		for i, b := range p[:n] {
			if b != '\r' && b != '\n' {
				if i != offset {
					p[offset] = b
				}
				offset++
			}
		}
		// Return as soon as there is data or an error to report. Retrying
		// after an error would silently discard it and read from a source
		// that has already failed or reached EOF.
		if err != nil || offset > 0 {
			return offset, err
		}
		// The chunk was entirely newlines; read again.
		n, err = r.wrapped.Read(p)
	}
	return n, err
}
597
598
599 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
600 return &decoder{enc: enc, r: &newlineFilteringReader{r}}
601 }
602
603
604
605 func (enc *Encoding) DecodedLen(n int) int {
606 if enc.padChar == NoPadding {
607
608 return n * 6 / 8
609 }
610
611 return n / 4 * 3
612 }
613
View as plain text