1
2
3
4
5 package singleflight
6
7 import (
8 "errors"
9 "fmt"
10 "sync"
11 "sync/atomic"
12 "testing"
13 "time"
14 )
15
16 func TestDo(t *testing.T) {
17 var g Group
18 v, err, _ := g.Do("key", func() (any, error) {
19 return "bar", nil
20 })
21 if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
22 t.Errorf("Do = %v; want %v", got, want)
23 }
24 if err != nil {
25 t.Errorf("Do error = %v", err)
26 }
27 }
28
29 func TestDoErr(t *testing.T) {
30 var g Group
31 someErr := errors.New("some error")
32 v, err, _ := g.Do("key", func() (any, error) {
33 return nil, someErr
34 })
35 if err != someErr {
36 t.Errorf("Do error = %v; want someErr %v", err, someErr)
37 }
38 if v != nil {
39 t.Errorf("unexpected non-nil value %#v", v)
40 }
41 }
42
43 func TestDoDupSuppress(t *testing.T) {
44 var g Group
45 var wg1, wg2 sync.WaitGroup
46 c := make(chan string, 1)
47 var calls int32
48 fn := func() (any, error) {
49 if atomic.AddInt32(&calls, 1) == 1 {
50
51 wg1.Done()
52 }
53 v := <-c
54 c <- v
55
56 time.Sleep(10 * time.Millisecond)
57
58 return v, nil
59 }
60
61 const n = 10
62 wg1.Add(1)
63 for i := 0; i < n; i++ {
64 wg1.Add(1)
65 wg2.Add(1)
66 go func() {
67 defer wg2.Done()
68 wg1.Done()
69 v, err, _ := g.Do("key", fn)
70 if err != nil {
71 t.Errorf("Do error: %v", err)
72 return
73 }
74 if s, _ := v.(string); s != "bar" {
75 t.Errorf("Do = %T %v; want %q", v, v, "bar")
76 }
77 }()
78 }
79 wg1.Wait()
80
81
82 c <- "bar"
83 wg2.Wait()
84 if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
85 t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
86 }
87 }
88
View as plain text