Source file src/internal/singleflight/singleflight_test.go

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package singleflight
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"sync"
    11  	"sync/atomic"
    12  	"testing"
    13  	"time"
    14  )
    15  
    16  func TestDo(t *testing.T) {
    17  	var g Group
    18  	v, err, _ := g.Do("key", func() (any, error) {
    19  		return "bar", nil
    20  	})
    21  	if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
    22  		t.Errorf("Do = %v; want %v", got, want)
    23  	}
    24  	if err != nil {
    25  		t.Errorf("Do error = %v", err)
    26  	}
    27  }
    28  
    29  func TestDoErr(t *testing.T) {
    30  	var g Group
    31  	someErr := errors.New("some error")
    32  	v, err, _ := g.Do("key", func() (any, error) {
    33  		return nil, someErr
    34  	})
    35  	if err != someErr {
    36  		t.Errorf("Do error = %v; want someErr %v", err, someErr)
    37  	}
    38  	if v != nil {
    39  		t.Errorf("unexpected non-nil value %#v", v)
    40  	}
    41  }
    42  
    43  func TestDoDupSuppress(t *testing.T) {
    44  	var g Group
    45  	var wg1, wg2 sync.WaitGroup
    46  	c := make(chan string, 1)
    47  	var calls atomic.Int32
    48  	fn := func() (any, error) {
    49  		if calls.Add(1) == 1 {
    50  			// First invocation.
    51  			wg1.Done()
    52  		}
    53  		v := <-c
    54  		c <- v // pump; make available for any future calls
    55  
    56  		time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
    57  
    58  		return v, nil
    59  	}
    60  
    61  	const n = 10
    62  	wg1.Add(1)
    63  	for i := 0; i < n; i++ {
    64  		wg1.Add(1)
    65  		wg2.Add(1)
    66  		go func() {
    67  			defer wg2.Done()
    68  			wg1.Done()
    69  			v, err, _ := g.Do("key", fn)
    70  			if err != nil {
    71  				t.Errorf("Do error: %v", err)
    72  				return
    73  			}
    74  			if s, _ := v.(string); s != "bar" {
    75  				t.Errorf("Do = %T %v; want %q", v, v, "bar")
    76  			}
    77  		}()
    78  	}
    79  	wg1.Wait()
    80  	// At least one goroutine is in fn now and all of them have at
    81  	// least reached the line before the Do.
    82  	c <- "bar"
    83  	wg2.Wait()
    84  	if got := calls.Load(); got <= 0 || got >= n {
    85  		t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
    86  	}
    87  }
    88  
    89  func TestForgetUnshared(t *testing.T) {
    90  	var g Group
    91  
    92  	var firstStarted, firstFinished sync.WaitGroup
    93  
    94  	firstStarted.Add(1)
    95  	firstFinished.Add(1)
    96  
    97  	key := "key"
    98  	firstCh := make(chan struct{})
    99  	go func() {
   100  		g.Do(key, func() (i interface{}, e error) {
   101  			firstStarted.Done()
   102  			<-firstCh
   103  			return
   104  		})
   105  		firstFinished.Done()
   106  	}()
   107  
   108  	firstStarted.Wait()
   109  	g.ForgetUnshared(key) // from this point no two function using same key should be executed concurrently
   110  
   111  	secondCh := make(chan struct{})
   112  	go func() {
   113  		g.Do(key, func() (i interface{}, e error) {
   114  			// Notify that we started
   115  			secondCh <- struct{}{}
   116  			<-secondCh
   117  			return 2, nil
   118  		})
   119  	}()
   120  
   121  	<-secondCh
   122  
   123  	resultCh := g.DoChan(key, func() (i interface{}, e error) {
   124  		panic("third must not be started")
   125  	})
   126  
   127  	if g.ForgetUnshared(key) {
   128  		t.Errorf("Before first goroutine finished, key %q is shared, should return false", key)
   129  	}
   130  
   131  	close(firstCh)
   132  	firstFinished.Wait()
   133  
   134  	if g.ForgetUnshared(key) {
   135  		t.Errorf("After first goroutine finished, key %q is still shared, should return false", key)
   136  	}
   137  
   138  	secondCh <- struct{}{}
   139  
   140  	if result := <-resultCh; result.Val != 2 {
   141  		t.Errorf("We should receive result produced by second call, expected: 2, got %d", result.Val)
   142  	}
   143  }
   144  
   145  func TestDoAndForgetUnsharedRace(t *testing.T) {
   146  	t.Parallel()
   147  
   148  	var g Group
   149  	key := "key"
   150  	d := time.Millisecond
   151  	for {
   152  		var calls, shared atomic.Int64
   153  		const n = 1000
   154  		var wg sync.WaitGroup
   155  		wg.Add(n)
   156  		for i := 0; i < n; i++ {
   157  			go func() {
   158  				g.Do(key, func() (interface{}, error) {
   159  					time.Sleep(d)
   160  					return calls.Add(1), nil
   161  				})
   162  				if !g.ForgetUnshared(key) {
   163  					shared.Add(1)
   164  				}
   165  				wg.Done()
   166  			}()
   167  		}
   168  		wg.Wait()
   169  
   170  		if calls.Load() != 1 {
   171  			// The goroutines didn't park in g.Do in time,
   172  			// so the key was re-added and may have been shared after the call.
   173  			// Try again with more time to park.
   174  			d *= 2
   175  			continue
   176  		}
   177  
   178  		// All of the Do calls ended up sharing the first
   179  		// invocation, so the key should have been unused
   180  		// (and therefore unshared) when they returned.
   181  		if shared.Load() > 0 {
   182  			t.Errorf("after a single shared Do, ForgetUnshared returned false %d times", shared.Load())
   183  		}
   184  		break
   185  	}
   186  }
   187  

View as plain text