1
2
3
4
5 package sumdb
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "path"
12 "strings"
13 "sync"
14 "sync/atomic"
15
16 "golang.org/x/mod/module"
17 "golang.org/x/mod/sumdb/note"
18 "golang.org/x/mod/sumdb/tlog"
19 )
20
21
22
23
// ClientOps is everything a Client needs from its environment: access to
// the checksum database server, persistent configuration files, a local
// cache, and a way to report messages. The Client performs no I/O of its
// own; every external effect goes through these methods.
type ClientOps interface {
	// ReadRemote fetches the server resource at path. The Client asks for
	// record paths of the form "/lookup/<escaped path>@<escaped version>"
	// and tile paths of the form "/" + tile.Path().
	ReadRemote(path string) ([]byte, error)

	// ReadConfig returns the contents of the named configuration file.
	// The Client reads "key" (the verifier key of the database) and
	// "<name>/latest" (the signed tree note it last persisted). Empty
	// "<name>/latest" contents are treated as "no tree known yet".
	ReadConfig(file string) ([]byte, error)

	// WriteConfig updates file from old, the contents the Client last
	// read, to new. If the file no longer holds old, the implementation is
	// expected to return ErrWriteConflict; the Client then re-reads the
	// file, merges it, and tries again.
	WriteConfig(file string, old, new []byte) error

	// ReadCache returns the cached contents of file. The Client treats
	// any error as a cache miss and falls back to ReadRemote.
	ReadCache(file string) ([]byte, error)

	// WriteCache stores data under file in the cache. It returns no
	// error, so the Client cannot learn about failed writes.
	WriteCache(file string, data []byte)

	// Log receives a message to report. NOTE(review): not called from the
	// code in this part of the file; presumably informational output.
	Log(msg string)

	// SecurityError receives a human-readable report, including both
	// conflicting signed tree notes and a proof, when the server is caught
	// serving inconsistent trees.
	SecurityError(msg string)
}
71
72
var (
	// ErrWriteConflict is the error ClientOps.WriteConfig is expected to
	// return when the file it was asked to update no longer holds the old
	// contents. The Client reacts by re-reading the file and retrying.
	ErrWriteConflict = errors.New("write conflict")

	// ErrSecurity is returned when the Client finds that the server has
	// served a tree that is inconsistent with one it signed earlier.
	ErrSecurity = errors.New("security error: misbehaving server")
)
77
78
79
// Client looks up and authenticates go.sum lines from a checksum database,
// performing all I/O through its ClientOps.
//
// Lookup may be called from multiple goroutines: shared state is guarded
// by initOnce, latestMu, tileSavedMu and the atomic didLookup flag. The
// Set* configuration methods must be called before the first Lookup (they
// panic otherwise).
type Client struct {
	ops ClientOps // environment: server, config files, cache, reporting

	didLookup uint32 // set to 1 atomically by Lookup; checked by the Set* methods

	// One-time initialization state (see init / initWork) and configuration.
	initOnce   sync.Once
	initErr    error          // error recorded by initWork, if any
	name       string         // verifier name; prefixes config and cache file names
	verifiers  note.Verifiers // built from the "key" config file; used to open tree notes
	tileReader tileReader     // tile source handed to tlog; back-pointer set by initWork
	tileHeight int            // set by SetTileHeight, or defaulted to 8 by initWork
	nosumdb    string         // comma-separated GONOSUMDB patterns from SetGONOSUMDB

	// Per-key call sharing for record and tile fetches. NOTE(review):
	// parCache is defined elsewhere; presumably it memoizes Do by key.
	record    parCache // keyed by record cache file name
	tileCache parCache // keyed by tlog.Tile

	latestMu  sync.Mutex // guards latest and latestMsg
	latest    tlog.Tree  // largest tree accepted so far
	latestMsg []byte     // signed note text that latest was parsed from

	tileSavedMu sync.Mutex         // guards tileSaved
	tileSaved   map[tlog.Tile]bool // tiles already in, or already being written to, the on-disk cache
}
104
105
106 func NewClient(ops ClientOps) *Client {
107 return &Client{
108 ops: ops,
109 }
110 }
111
112
113
114 func (c *Client) init() error {
115 c.initOnce.Do(c.initWork)
116 return c.initErr
117 }
118
119
120 func (c *Client) initWork() {
121 defer func() {
122 if c.initErr != nil {
123 c.initErr = fmt.Errorf("initializing sumdb.Client: %v", c.initErr)
124 }
125 }()
126
127 c.tileReader.c = c
128 if c.tileHeight == 0 {
129 c.tileHeight = 8
130 }
131 c.tileSaved = make(map[tlog.Tile]bool)
132
133 vkey, err := c.ops.ReadConfig("key")
134 if err != nil {
135 c.initErr = err
136 return
137 }
138 verifier, err := note.NewVerifier(strings.TrimSpace(string(vkey)))
139 if err != nil {
140 c.initErr = err
141 return
142 }
143 c.verifiers = note.VerifierList(verifier)
144 c.name = verifier.Name()
145
146 data, err := c.ops.ReadConfig(c.name + "/latest")
147 if err != nil {
148 c.initErr = err
149 return
150 }
151 if err := c.mergeLatest(data); err != nil {
152 c.initErr = err
153 return
154 }
155 }
156
157
158
159
160
161
162 func (c *Client) SetTileHeight(height int) {
163 if atomic.LoadUint32(&c.didLookup) != 0 {
164 panic("SetTileHeight used after Lookup")
165 }
166 if height <= 0 {
167 panic("invalid call to SetTileHeight")
168 }
169 if c.tileHeight != 0 {
170 panic("multiple calls to SetTileHeight")
171 }
172 c.tileHeight = height
173 }
174
175
176
177
178
179
180 func (c *Client) SetGONOSUMDB(list string) {
181 if atomic.LoadUint32(&c.didLookup) != 0 {
182 panic("SetGONOSUMDB used after Lookup")
183 }
184 if c.nosumdb != "" {
185 panic("multiple calls to SetGONOSUMDB")
186 }
187 c.nosumdb = list
188 }
189
190
191
192
var (
	// ErrGONOSUMDB is returned by Lookup, without any wrapping, for module
	// paths that match the patterns passed to SetGONOSUMDB.
	ErrGONOSUMDB = errors.New("skipped (listed in GONOSUMDB)")
)
194
195 func (c *Client) skip(target string) bool {
196 return globsMatchPath(c.nosumdb, target)
197 }
198
199
200
201
202
// globsMatchPath reports whether any path prefix of target matches one of
// the glob patterns (as defined by path.Match) in the comma-separated
// globs list. A pattern containing k slashes is matched against the first
// k+1 slash-separated elements of target. For example, "rsc.io/private"
// matches "rsc.io/private/quote" but not "rsc.io/privateer".
//
// Empty and malformed patterns are ignored. A single trailing slash on a
// pattern is also ignored, so "golang.org/" behaves like "golang.org".
func globsMatchPath(globs, target string) bool {
	for globs != "" {
		// Extract the next pattern from the comma-separated list.
		var glob string
		if i := strings.Index(globs, ","); i >= 0 {
			glob, globs = globs[:i], globs[i+1:]
		} else {
			glob, globs = globs, ""
		}
		// Without this, a pattern such as "golang.org/" would demand one
		// more path element than it names and could never match module
		// paths like "golang.org/x/mod".
		glob = strings.TrimSuffix(glob, "/")
		if glob == "" {
			continue
		}

		// Truncate target just before its (n+1)'th slash, where n is the
		// number of slashes in glob, so both have the same element count.
		n := strings.Count(glob, "/")
		prefix := target
		for i := 0; i < len(target); i++ {
			if target[i] == '/' {
				if n == 0 {
					prefix = target[:i]
					break
				}
				n--
			}
		}
		if n > 0 {
			// target has fewer path elements than the pattern.
			continue
		}
		// path.Match fails only on a malformed pattern; treat that as no match.
		if matched, _ := path.Match(glob, prefix); matched {
			return true
		}
	}
	return false
}
242
243
244
245
246 func (c *Client) Lookup(path, vers string) (lines []string, err error) {
247 atomic.StoreUint32(&c.didLookup, 1)
248
249 if c.skip(path) {
250 return nil, ErrGONOSUMDB
251 }
252
253 defer func() {
254 if err != nil {
255 err = fmt.Errorf("%s@%s: %v", path, vers, err)
256 }
257 }()
258
259 if err := c.init(); err != nil {
260 return nil, err
261 }
262
263
264 epath, err := module.EscapePath(path)
265 if err != nil {
266 return nil, err
267 }
268 evers, err := module.EscapeVersion(strings.TrimSuffix(vers, "/go.mod"))
269 if err != nil {
270 return nil, err
271 }
272 remotePath := "/lookup/" + epath + "@" + evers
273 file := c.name + remotePath
274
275
276
277
278
279
280 type cached struct {
281 data []byte
282 err error
283 }
284 result := c.record.Do(file, func() interface{} {
285
286 writeCache := false
287 data, err := c.ops.ReadCache(file)
288 if err != nil {
289 data, err = c.ops.ReadRemote(remotePath)
290 if err != nil {
291 return cached{nil, err}
292 }
293 writeCache = true
294 }
295
296
297 id, text, treeMsg, err := tlog.ParseRecord(data)
298 if err != nil {
299 return cached{nil, err}
300 }
301 if err := c.mergeLatest(treeMsg); err != nil {
302 return cached{nil, err}
303 }
304 if err := c.checkRecord(id, text); err != nil {
305 return cached{nil, err}
306 }
307
308
309
310 if writeCache {
311 c.ops.WriteCache(file, data)
312 }
313
314 return cached{data, nil}
315 }).(cached)
316 if result.err != nil {
317 return nil, result.err
318 }
319
320
321
322 prefix := path + " " + vers + " "
323 var hashes []string
324 for _, line := range strings.Split(string(result.data), "\n") {
325 if strings.HasPrefix(line, prefix) {
326 hashes = append(hashes, line)
327 }
328 }
329 return hashes, nil
330 }
331
332
333
334
335
336
337
338
339
340
// mergeLatest merges the signed tree note msg into the Client's latest
// tree. When msg advances the in-memory tree, it also brings the
// "<name>/latest" config file up to date.
//
// If msg is not newer than the in-memory tree, nothing is written.
// Otherwise the config file is re-read and merged into memory, which also
// checks any tree stored there for consistency. If the on-disk tree is
// older than the in-memory one, the in-memory note is written with
// WriteConfig. An ErrWriteConflict from WriteConfig repeats the
// read-merge-write cycle. Any other WriteConfig error is returned.
func (c *Client) mergeLatest(msg []byte) error {
	// Merge into memory first.
	when, err := c.mergeLatestMem(msg)
	if err != nil {
		return err
	}
	if when != msgFuture {
		// msg is the same as, or older than, the tree already in memory,
		// so there is nothing new to persist.
		return nil
	}

	// msg advanced the in-memory tree. Update the config file, first
	// merging whatever it currently holds (which may itself be newer, or
	// may have been changed concurrently by another writer).
	for {
		msg, err := c.ops.ReadConfig(c.name + "/latest")
		if err != nil {
			return err
		}
		when, err := c.mergeLatestMem(msg)
		if err != nil {
			return err
		}
		if when != msgPast {
			// The file already holds the in-memory tree (msgNow), or held a
			// newer one that has just been merged into memory (msgFuture).
			// Either way the file needs no update.
			return nil
		}

		// The file holds an older tree: replace it with the in-memory note,
		// passing the contents just read as old so that a concurrent change
		// surfaces as ErrWriteConflict.
		c.latestMu.Lock()
		latestMsg := c.latestMsg
		c.latestMu.Unlock()
		if err := c.ops.WriteConfig(c.name+"/latest", msg, latestMsg); err != ErrWriteConflict {
			// Success (nil) or an error other than a conflict.
			return err
		}
		// Write conflict: re-read the file and try again.
	}
}
382
// Results of mergeLatestMem, describing how a tree note compares with the
// Client's in-memory latest tree. They start at 1 so that the 0 returned
// alongside an error matches none of them.
const (
	msgPast   = 1 // older than the current latest tree
	msgNow    = 2 // same size as the current latest tree
	msgFuture = 3 // newer; has been installed as the latest tree
)
388
389
390
391
392
393
394
395
396
397 func (c *Client) mergeLatestMem(msg []byte) (when int, err error) {
398 if len(msg) == 0 {
399
400 c.latestMu.Lock()
401 latest := c.latest
402 c.latestMu.Unlock()
403 if latest.N == 0 {
404 return msgNow, nil
405 }
406 return msgPast, nil
407 }
408
409 note, err := note.Open(msg, c.verifiers)
410 if err != nil {
411 return 0, fmt.Errorf("reading tree note: %v\nnote:\n%s", err, msg)
412 }
413 tree, err := tlog.ParseTree([]byte(note.Text))
414 if err != nil {
415 return 0, fmt.Errorf("reading tree: %v\ntree:\n%s", err, note.Text)
416 }
417
418
419
420
421 c.latestMu.Lock()
422 latest := c.latest
423 latestMsg := c.latestMsg
424 c.latestMu.Unlock()
425
426 for {
427
428 if tree.N <= latest.N {
429 if err := c.checkTrees(tree, msg, latest, latestMsg); err != nil {
430 return 0, err
431 }
432 if tree.N < latest.N {
433 return msgPast, nil
434 }
435 return msgNow, nil
436 }
437
438
439 if err := c.checkTrees(latest, latestMsg, tree, msg); err != nil {
440 return 0, err
441 }
442
443
444
445 c.latestMu.Lock()
446 installed := false
447 if c.latest == latest {
448 installed = true
449 c.latest = tree
450 c.latestMsg = msg
451 } else {
452 latest = c.latest
453 latestMsg = c.latestMsg
454 }
455 c.latestMu.Unlock()
456
457 if installed {
458 return msgFuture, nil
459 }
460 }
461 }
462
463
464
465
466
467 func (c *Client) checkTrees(older tlog.Tree, olderNote []byte, newer tlog.Tree, newerNote []byte) error {
468 thr := tlog.TileHashReader(newer, &c.tileReader)
469 h, err := tlog.TreeHash(older.N, thr)
470 if err != nil {
471 if older.N == newer.N {
472 return fmt.Errorf("checking tree#%d: %v", older.N, err)
473 }
474 return fmt.Errorf("checking tree#%d against tree#%d: %v", older.N, newer.N, err)
475 }
476 if h == older.Hash {
477 return nil
478 }
479
480
481
482 var buf bytes.Buffer
483 fmt.Fprintf(&buf, "SECURITY ERROR\n")
484 fmt.Fprintf(&buf, "go.sum database server misbehavior detected!\n\n")
485 indent := func(b []byte) []byte {
486 return bytes.Replace(b, []byte("\n"), []byte("\n\t"), -1)
487 }
488 fmt.Fprintf(&buf, "old database:\n\t%s\n", indent(olderNote))
489 fmt.Fprintf(&buf, "new database:\n\t%s\n", indent(newerNote))
490
491
492
493
494
495
496
497
498
499
500
501 fmt.Fprintf(&buf, "proof of misbehavior:\n\t%v", h)
502 if p, err := tlog.ProveTree(newer.N, older.N, thr); err != nil {
503 fmt.Fprintf(&buf, "\tinternal error: %v\n", err)
504 } else if err := tlog.CheckTree(p, newer.N, newer.Hash, older.N, h); err != nil {
505 fmt.Fprintf(&buf, "\tinternal error: generated inconsistent proof\n")
506 } else {
507 for _, h := range p {
508 fmt.Fprintf(&buf, "\n\t%v", h)
509 }
510 }
511 c.ops.SecurityError(buf.String())
512 return ErrSecurity
513 }
514
515
516 func (c *Client) checkRecord(id int64, data []byte) error {
517 c.latestMu.Lock()
518 latest := c.latest
519 c.latestMu.Unlock()
520
521 if id >= latest.N {
522 return fmt.Errorf("cannot validate record %d in tree of size %d", id, latest.N)
523 }
524 hashes, err := tlog.TileHashReader(latest, &c.tileReader).ReadHashes([]int64{tlog.StoredHashIndex(0, id)})
525 if err != nil {
526 return err
527 }
528 if hashes[0] == tlog.RecordHash(data) {
529 return nil
530 }
531 return fmt.Errorf("cannot authenticate record data in server response")
532 }
533
534
535
536
537 type tileReader struct {
538 c *Client
539 }
540
541 func (r *tileReader) Height() int {
542 return r.c.tileHeight
543 }
544
545
546
547 func (r *tileReader) ReadTiles(tiles []tlog.Tile) ([][]byte, error) {
548
549 data := make([][]byte, len(tiles))
550 errs := make([]error, len(tiles))
551 var wg sync.WaitGroup
552 for i, tile := range tiles {
553 wg.Add(1)
554 go func(i int, tile tlog.Tile) {
555 defer wg.Done()
556 data[i], errs[i] = r.c.readTile(tile)
557 }(i, tile)
558 }
559 wg.Wait()
560
561 for _, err := range errs {
562 if err != nil {
563 return nil, err
564 }
565 }
566
567 return data, nil
568 }
569
570
571 func (c *Client) tileCacheKey(tile tlog.Tile) string {
572 return c.name + "/" + tile.Path()
573 }
574
575
576 func (c *Client) tileRemotePath(tile tlog.Tile) string {
577 return "/" + tile.Path()
578 }
579
580
// readTile returns the data for tile, trying in order:
//
//  1. the on-disk cache entry for tile;
//  2. the cache entry for the full-width version of tile, truncated to
//     tile's width;
//  3. the server's copy of tile;
//  4. the server's copy of the full-width tile, truncated the same way.
//
// Tiles satisfied from the cache are marked saved so SaveTiles will not
// write them back. If every attempt fails, the error from the server fetch
// of the requested tile (step 3) is returned. The work runs inside
// c.tileCache.Do keyed by tile. NOTE(review): parCache is defined
// elsewhere; presumably it shares one fetch per tile among callers.
func (c *Client) readTile(tile tlog.Tile) ([]byte, error) {
	type cached struct {
		data []byte
		err  error
	}

	result := c.tileCache.Do(tile, func() interface{} {
		// 1. Requested tile from the on-disk cache.
		data, err := c.ops.ReadCache(c.tileCacheKey(tile))
		if err == nil {
			c.markTileSaved(tile)
			return cached{data, nil}
		}

		// 2. Full-width tile from the on-disk cache (only when the request
		// is for a partial tile). The truncation keeps the first tile.W of
		// full.W entries; it assumes the data is full.W equal-size entries.
		full := tile
		full.W = 1 << uint(tile.H)
		if tile != full {
			data, err := c.ops.ReadCache(c.tileCacheKey(full))
			if err == nil {
				c.markTileSaved(tile) // the full tile is already cached; don't save the partial one
				return cached{data[:len(data)/full.W*tile.W], nil}
			}
		}

		// 3. Requested tile from the server. This err (not the shadowed ones
		// in the if-blocks) is what gets returned if everything fails.
		data, err = c.ops.ReadRemote(c.tileRemotePath(tile))
		if err == nil {
			return cached{data, nil}
		}

		// 4. Full-width tile from the server, truncated to tile's width.
		// It is not marked saved, so SaveTiles will later persist the
		// (partial) tile it is given rather than this unverified full tile.
		if tile != full {
			data, err := c.ops.ReadRemote(c.tileRemotePath(full))
			if err == nil {
				return cached{data[:len(data)/full.W*tile.W], nil}
			}
		}

		// Nothing worked: report the error from step 3.
		return cached{nil, err}
	}).(cached)

	return result.data, result.err
}
637
638
639
640 func (c *Client) markTileSaved(tile tlog.Tile) {
641 c.tileSavedMu.Lock()
642 c.tileSaved[tile] = true
643 c.tileSavedMu.Unlock()
644 }
645
646
647 func (r *tileReader) SaveTiles(tiles []tlog.Tile, data [][]byte) {
648 c := r.c
649
650
651
652 save := make([]bool, len(tiles))
653 c.tileSavedMu.Lock()
654 for i, tile := range tiles {
655 if !c.tileSaved[tile] {
656 save[i] = true
657 c.tileSaved[tile] = true
658 }
659 }
660 c.tileSavedMu.Unlock()
661
662 for i, tile := range tiles {
663 if save[i] {
664
665
666
667
668 c.ops.WriteCache(c.name+"/"+tile.Path(), data[i])
669 }
670 }
671 }
672
View as plain text