1
2
3
4
5
6
7
8
9
10
11
12 package tlog
13
14 import (
15 "crypto/sha256"
16 "encoding/base64"
17 "errors"
18 "fmt"
19 "math/bits"
20 )
21
22
// HashSize is the size of a Hash in bytes.
const HashSize = 32

// A Hash is a SHA-256 digest as used throughout the log: record hashes
// (RecordHash), interior node hashes (NodeHash) and tree roots.
type Hash [HashSize]byte
27
28
29 func (h Hash) String() string {
30 return base64.StdEncoding.EncodeToString(h[:])
31 }
32
33
34 func (h Hash) MarshalJSON() ([]byte, error) {
35 return []byte(`"` + h.String() + `"`), nil
36 }
37
38
39 func (h *Hash) UnmarshalJSON(data []byte) error {
40 if len(data) != 1+44+1 || data[0] != '"' || data[len(data)-2] != '=' || data[len(data)-1] != '"' {
41 return errors.New("cannot decode hash")
42 }
43
44
45
46
47
48
49
50
51 var tmp Hash
52 n, err := base64.RawStdEncoding.Decode(tmp[:], data[1:len(data)-2])
53 if err != nil || n != HashSize {
54 return errors.New("cannot decode hash")
55 }
56 *h = tmp
57 return nil
58 }
59
60
61 func ParseHash(s string) (Hash, error) {
62 data, err := base64.StdEncoding.DecodeString(s)
63 if err != nil || len(data) != HashSize {
64 return Hash{}, fmt.Errorf("malformed hash")
65 }
66 var h Hash
67 copy(h[:], data)
68 return h, nil
69 }
70
71
72
// maxpow2 returns k, the largest power of 2 strictly smaller than n, together
// with l = log₂ k (so k == 1<<l). For n <= 2 the result is (1, 0).
//
// Callers pass hi-lo (giving the size of the left child of [lo, hi)) or
// hi-lo+1 (giving the largest complete subtree that fits in [lo, hi)).
func maxpow2(n int64) (k int64, l int) {
	if n <= 2 {
		return 1, 0
	}
	// The largest power of 2 below n is the top set bit of n-1. Computing it
	// directly replaces a shift loop that overflowed int64 and never
	// terminated for n > 1<<62.
	l = bits.Len64(uint64(n-1)) - 1
	return 1 << uint(l), l
}
80
81 var zeroPrefix = []byte{0x00}
82
83
84 func RecordHash(data []byte) Hash {
85
86
87 h := sha256.New()
88 h.Write(zeroPrefix)
89 h.Write(data)
90 var h1 Hash
91 h.Sum(h1[:0])
92 return h1
93 }
94
95
96 func NodeHash(left, right Hash) Hash {
97
98
99
100
101 var buf [1 + HashSize + HashSize]byte
102 buf[0] = 0x01
103 copy(buf[1:], left[:])
104 copy(buf[1+HashSize:], right[:])
105 return sha256.Sum256(buf[:])
106 }
107
108
109
110
111
112
113
114
115
116
117
// StoredHashIndex maps the tree coordinates (level, n) to a dense linear
// ordering suitable for sequential hash storage: implementations can use it
// to compute where to read or write a given hash.
func StoredHashIndex(level int, n int64) int64 {
	// Hash (level, n) is written right after the level-0 hash of the record
	// (n+1)<<level - 1 that completes its subtree, following that record's
	// lower-level hashes. Each step down one level maps n to 2n+1.
	for up := 0; up < level; up++ {
		n = n*2 + 1
	}

	// The level-0 hash of record m is preceded by m leaf hashes plus all
	// interior hashes completed so far: m + m/2 + m/4 + ... slots in total.
	var offset int64
	for m := n; m > 0; m /= 2 {
		offset += m
	}

	// Add back the levels stored after that record's leaf hash.
	return offset + int64(level)
}
134
135
136
137 func SplitStoredHashIndex(index int64) (level int, n int64) {
138
139
140
141 n = index / 2
142 indexN := StoredHashIndex(0, n)
143 if indexN > index {
144 panic("bad math")
145 }
146 for {
147
148 x := indexN + 1 + int64(bits.TrailingZeros64(uint64(n+1)))
149 if x > index {
150 break
151 }
152 n++
153 indexN = x
154 }
155
156
157 level = int(index - indexN)
158 return level, n >> uint(level)
159 }
160
161
162
163 func StoredHashCount(n int64) int64 {
164 if n == 0 {
165 return 0
166 }
167
168 numHash := StoredHashIndex(0, n-1) + 1
169
170 for i := uint64(n - 1); i&1 != 0; i >>= 1 {
171 numHash++
172 }
173 return numHash
174 }
175
176
177
178
179
180
181
182
183 func StoredHashes(n int64, data []byte, r HashReader) ([]Hash, error) {
184 return StoredHashesForRecordHash(n, RecordHash(data), r)
185 }
186
187
188
189 func StoredHashesForRecordHash(n int64, h Hash, r HashReader) ([]Hash, error) {
190
191 hashes := []Hash{h}
192
193
194
195
196 m := int(bits.TrailingZeros64(uint64(n + 1)))
197 indexes := make([]int64, m)
198 for i := 0; i < m; i++ {
199
200
201 indexes[m-1-i] = StoredHashIndex(i, n>>uint(i)-1)
202 }
203
204
205 old, err := r.ReadHashes(indexes)
206 if err != nil {
207 return nil, err
208 }
209 if len(old) != len(indexes) {
210 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(old))
211 }
212
213
214 for i := 0; i < m; i++ {
215 h = NodeHash(old[m-1-i], h)
216 hashes = append(hashes, h)
217 }
218 return hashes, nil
219 }
220
221
222 type HashReader interface {
223
224
225
226
227
228 ReadHashes(indexes []int64) ([]Hash, error)
229 }
230
231
232 type HashReaderFunc func([]int64) ([]Hash, error)
233
234 func (f HashReaderFunc) ReadHashes(indexes []int64) ([]Hash, error) {
235 return f(indexes)
236 }
237
238
239
240
241
242
243 func TreeHash(n int64, r HashReader) (Hash, error) {
244 if n == 0 {
245 return Hash{}, nil
246 }
247 indexes := subTreeIndex(0, n, nil)
248 hashes, err := r.ReadHashes(indexes)
249 if err != nil {
250 return Hash{}, err
251 }
252 if len(hashes) != len(indexes) {
253 return Hash{}, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
254 }
255 hash, hashes := subTreeHash(0, n, hashes)
256 if len(hashes) != 0 {
257 panic("tlog: bad index math in TreeHash")
258 }
259 return hash, nil
260 }
261
262
263
264
265
266 func subTreeIndex(lo, hi int64, need []int64) []int64 {
267
268 for lo < hi {
269 k, level := maxpow2(hi - lo + 1)
270 if lo&(k-1) != 0 {
271 panic("tlog: bad math in subTreeIndex")
272 }
273 need = append(need, StoredHashIndex(level, lo>>uint(level)))
274 lo += k
275 }
276 return need
277 }
278
279
280
281
282
283 func subTreeHash(lo, hi int64, hashes []Hash) (Hash, []Hash) {
284
285
286
287
288 numTree := 0
289 for lo < hi {
290 k, _ := maxpow2(hi - lo + 1)
291 if lo&(k-1) != 0 || lo >= hi {
292 panic("tlog: bad math in subTreeHash")
293 }
294 numTree++
295 lo += k
296 }
297
298 if len(hashes) < numTree {
299 panic("tlog: bad index math in subTreeHash")
300 }
301
302
303 h := hashes[numTree-1]
304 for i := numTree - 2; i >= 0; i-- {
305 h = NodeHash(hashes[i], h)
306 }
307 return h, hashes[numTree:]
308 }
309
310
311
312 type RecordProof []Hash
313
314
315 func ProveRecord(t, n int64, r HashReader) (RecordProof, error) {
316 if t < 0 || n < 0 || n >= t {
317 return nil, fmt.Errorf("tlog: invalid inputs in ProveRecord")
318 }
319 indexes := leafProofIndex(0, t, n, nil)
320 if len(indexes) == 0 {
321 return RecordProof{}, nil
322 }
323 hashes, err := r.ReadHashes(indexes)
324 if err != nil {
325 return nil, err
326 }
327 if len(hashes) != len(indexes) {
328 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
329 }
330
331 p, hashes := leafProof(0, t, n, hashes)
332 if len(hashes) != 0 {
333 panic("tlog: bad index math in ProveRecord")
334 }
335 return p, nil
336 }
337
338
339
340
341
342 func leafProofIndex(lo, hi, n int64, need []int64) []int64 {
343
344 if !(lo <= n && n < hi) {
345 panic("tlog: bad math in leafProofIndex")
346 }
347 if lo+1 == hi {
348 return need
349 }
350 if k, _ := maxpow2(hi - lo); n < lo+k {
351 need = leafProofIndex(lo, lo+k, n, need)
352 need = subTreeIndex(lo+k, hi, need)
353 } else {
354 need = subTreeIndex(lo, lo+k, need)
355 need = leafProofIndex(lo+k, hi, n, need)
356 }
357 return need
358 }
359
360
361
362
363 func leafProof(lo, hi, n int64, hashes []Hash) (RecordProof, []Hash) {
364
365 if !(lo <= n && n < hi) {
366 panic("tlog: bad math in leafProof")
367 }
368
369 if lo+1 == hi {
370
371
372 return RecordProof{}, hashes
373 }
374
375
376
377 var p RecordProof
378 var th Hash
379 if k, _ := maxpow2(hi - lo); n < lo+k {
380
381 p, hashes = leafProof(lo, lo+k, n, hashes)
382 th, hashes = subTreeHash(lo+k, hi, hashes)
383 } else {
384
385 th, hashes = subTreeHash(lo, lo+k, hashes)
386 p, hashes = leafProof(lo+k, hi, n, hashes)
387 }
388 return append(p, th), hashes
389 }
390
391 var errProofFailed = errors.New("invalid transparency proof")
392
393
394
395 func CheckRecord(p RecordProof, t int64, th Hash, n int64, h Hash) error {
396 if t < 0 || n < 0 || n >= t {
397 return fmt.Errorf("tlog: invalid inputs in CheckRecord")
398 }
399 th2, err := runRecordProof(p, 0, t, n, h)
400 if err != nil {
401 return err
402 }
403 if th2 == th {
404 return nil
405 }
406 return errProofFailed
407 }
408
409
410
411
412 func runRecordProof(p RecordProof, lo, hi, n int64, leafHash Hash) (Hash, error) {
413
414 if !(lo <= n && n < hi) {
415 panic("tlog: bad math in runRecordProof")
416 }
417
418 if lo+1 == hi {
419
420
421 if len(p) != 0 {
422 return Hash{}, errProofFailed
423 }
424 return leafHash, nil
425 }
426
427 if len(p) == 0 {
428 return Hash{}, errProofFailed
429 }
430
431 k, _ := maxpow2(hi - lo)
432 if n < lo+k {
433 th, err := runRecordProof(p[:len(p)-1], lo, lo+k, n, leafHash)
434 if err != nil {
435 return Hash{}, err
436 }
437 return NodeHash(th, p[len(p)-1]), nil
438 } else {
439 th, err := runRecordProof(p[:len(p)-1], lo+k, hi, n, leafHash)
440 if err != nil {
441 return Hash{}, err
442 }
443 return NodeHash(p[len(p)-1], th), nil
444 }
445 }
446
447
448
449
450 type TreeProof []Hash
451
452
453
454 func ProveTree(t, n int64, h HashReader) (TreeProof, error) {
455 if t < 1 || n < 1 || n > t {
456 return nil, fmt.Errorf("tlog: invalid inputs in ProveTree")
457 }
458 indexes := treeProofIndex(0, t, n, nil)
459 if len(indexes) == 0 {
460 return TreeProof{}, nil
461 }
462 hashes, err := h.ReadHashes(indexes)
463 if err != nil {
464 return nil, err
465 }
466 if len(hashes) != len(indexes) {
467 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
468 }
469
470 p, hashes := treeProof(0, t, n, hashes)
471 if len(hashes) != 0 {
472 panic("tlog: bad index math in ProveTree")
473 }
474 return p, nil
475 }
476
477
478
479
480 func treeProofIndex(lo, hi, n int64, need []int64) []int64 {
481
482 if !(lo < n && n <= hi) {
483 panic("tlog: bad math in treeProofIndex")
484 }
485
486 if n == hi {
487 if lo == 0 {
488 return need
489 }
490 return subTreeIndex(lo, hi, need)
491 }
492
493 if k, _ := maxpow2(hi - lo); n <= lo+k {
494 need = treeProofIndex(lo, lo+k, n, need)
495 need = subTreeIndex(lo+k, hi, need)
496 } else {
497 need = subTreeIndex(lo, lo+k, need)
498 need = treeProofIndex(lo+k, hi, n, need)
499 }
500 return need
501 }
502
503
504
505
506 func treeProof(lo, hi, n int64, hashes []Hash) (TreeProof, []Hash) {
507
508 if !(lo < n && n <= hi) {
509 panic("tlog: bad math in treeProof")
510 }
511
512
513 if n == hi {
514 if lo == 0 {
515
516
517 return TreeProof{}, hashes
518 }
519 th, hashes := subTreeHash(lo, hi, hashes)
520 return TreeProof{th}, hashes
521 }
522
523
524
525 var p TreeProof
526 var th Hash
527 if k, _ := maxpow2(hi - lo); n <= lo+k {
528
529 p, hashes = treeProof(lo, lo+k, n, hashes)
530 th, hashes = subTreeHash(lo+k, hi, hashes)
531 } else {
532
533 th, hashes = subTreeHash(lo, lo+k, hashes)
534 p, hashes = treeProof(lo+k, hi, n, hashes)
535 }
536 return append(p, th), hashes
537 }
538
539
540
541 func CheckTree(p TreeProof, t int64, th Hash, n int64, h Hash) error {
542 if t < 1 || n < 1 || n > t {
543 return fmt.Errorf("tlog: invalid inputs in CheckTree")
544 }
545 h2, th2, err := runTreeProof(p, 0, t, n, h)
546 if err != nil {
547 return err
548 }
549 if th2 == th && h2 == h {
550 return nil
551 }
552 return errProofFailed
553 }
554
555
556
557
558
559 func runTreeProof(p TreeProof, lo, hi, n int64, old Hash) (Hash, Hash, error) {
560
561 if !(lo < n && n <= hi) {
562 panic("tlog: bad math in runTreeProof")
563 }
564
565
566 if n == hi {
567 if lo == 0 {
568 if len(p) != 0 {
569 return Hash{}, Hash{}, errProofFailed
570 }
571 return old, old, nil
572 }
573 if len(p) != 1 {
574 return Hash{}, Hash{}, errProofFailed
575 }
576 return p[0], p[0], nil
577 }
578
579 if len(p) == 0 {
580 return Hash{}, Hash{}, errProofFailed
581 }
582
583
584 k, _ := maxpow2(hi - lo)
585 if n <= lo+k {
586 oh, th, err := runTreeProof(p[:len(p)-1], lo, lo+k, n, old)
587 if err != nil {
588 return Hash{}, Hash{}, err
589 }
590 return oh, NodeHash(th, p[len(p)-1]), nil
591 } else {
592 oh, th, err := runTreeProof(p[:len(p)-1], lo+k, hi, n, old)
593 if err != nil {
594 return Hash{}, Hash{}, err
595 }
596 return NodeHash(p[len(p)-1], oh), NodeHash(p[len(p)-1], th), nil
597 }
598 }
599
View as plain text