Source file src/internal/singleflight/singleflight_test.go

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package singleflight
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"sync"
    11  	"sync/atomic"
    12  	"testing"
    13  	"time"
    14  )
    15  
    16  func TestDo(t *testing.T) {
    17  	var g Group
    18  	v, err, _ := g.Do("key", func() (any, error) {
    19  		return "bar", nil
    20  	})
    21  	if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
    22  		t.Errorf("Do = %v; want %v", got, want)
    23  	}
    24  	if err != nil {
    25  		t.Errorf("Do error = %v", err)
    26  	}
    27  }
    28  
    29  func TestDoErr(t *testing.T) {
    30  	var g Group
    31  	someErr := errors.New("some error")
    32  	v, err, _ := g.Do("key", func() (any, error) {
    33  		return nil, someErr
    34  	})
    35  	if err != someErr {
    36  		t.Errorf("Do error = %v; want someErr %v", err, someErr)
    37  	}
    38  	if v != nil {
    39  		t.Errorf("unexpected non-nil value %#v", v)
    40  	}
    41  }
    42  
    43  func TestDoDupSuppress(t *testing.T) {
    44  	var g Group
    45  	var wg1, wg2 sync.WaitGroup
    46  	c := make(chan string, 1)
    47  	var calls int32
    48  	fn := func() (any, error) {
    49  		if atomic.AddInt32(&calls, 1) == 1 {
    50  			// First invocation.
    51  			wg1.Done()
    52  		}
    53  		v := <-c
    54  		c <- v // pump; make available for any future calls
    55  
    56  		time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
    57  
    58  		return v, nil
    59  	}
    60  
    61  	const n = 10
    62  	wg1.Add(1)
    63  	for i := 0; i < n; i++ {
    64  		wg1.Add(1)
    65  		wg2.Add(1)
    66  		go func() {
    67  			defer wg2.Done()
    68  			wg1.Done()
    69  			v, err, _ := g.Do("key", fn)
    70  			if err != nil {
    71  				t.Errorf("Do error: %v", err)
    72  				return
    73  			}
    74  			if s, _ := v.(string); s != "bar" {
    75  				t.Errorf("Do = %T %v; want %q", v, v, "bar")
    76  			}
    77  		}()
    78  	}
    79  	wg1.Wait()
    80  	// At least one goroutine is in fn now and all of them have at
    81  	// least reached the line before the Do.
    82  	c <- "bar"
    83  	wg2.Wait()
    84  	if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
    85  		t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
    86  	}
    87  }
    88  

View as plain text