Source file src/sync/atomic/value_test.go

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package atomic_test
     6  
     7  import (
     8  	"math/rand"
     9  	"runtime"
    10  	"strconv"
    11  	"sync"
    12  	"sync/atomic"
    13  	. "sync/atomic"
    14  	"testing"
    15  )
    16  
    17  func TestValue(t *testing.T) {
    18  	var v Value
    19  	if v.Load() != nil {
    20  		t.Fatal("initial Value is not nil")
    21  	}
    22  	v.Store(42)
    23  	x := v.Load()
    24  	if xx, ok := x.(int); !ok || xx != 42 {
    25  		t.Fatalf("wrong value: got %+v, want 42", x)
    26  	}
    27  	v.Store(84)
    28  	x = v.Load()
    29  	if xx, ok := x.(int); !ok || xx != 84 {
    30  		t.Fatalf("wrong value: got %+v, want 84", x)
    31  	}
    32  }
    33  
    34  func TestValueLarge(t *testing.T) {
    35  	var v Value
    36  	v.Store("foo")
    37  	x := v.Load()
    38  	if xx, ok := x.(string); !ok || xx != "foo" {
    39  		t.Fatalf("wrong value: got %+v, want foo", x)
    40  	}
    41  	v.Store("barbaz")
    42  	x = v.Load()
    43  	if xx, ok := x.(string); !ok || xx != "barbaz" {
    44  		t.Fatalf("wrong value: got %+v, want barbaz", x)
    45  	}
    46  }
    47  
    48  func TestValuePanic(t *testing.T) {
    49  	const nilErr = "sync/atomic: store of nil value into Value"
    50  	const badErr = "sync/atomic: store of inconsistently typed value into Value"
    51  	var v Value
    52  	func() {
    53  		defer func() {
    54  			err := recover()
    55  			if err != nilErr {
    56  				t.Fatalf("inconsistent store panic: got '%v', want '%v'", err, nilErr)
    57  			}
    58  		}()
    59  		v.Store(nil)
    60  	}()
    61  	v.Store(42)
    62  	func() {
    63  		defer func() {
    64  			err := recover()
    65  			if err != badErr {
    66  				t.Fatalf("inconsistent store panic: got '%v', want '%v'", err, badErr)
    67  			}
    68  		}()
    69  		v.Store("foo")
    70  	}()
    71  	func() {
    72  		defer func() {
    73  			err := recover()
    74  			if err != nilErr {
    75  				t.Fatalf("inconsistent store panic: got '%v', want '%v'", err, nilErr)
    76  			}
    77  		}()
    78  		v.Store(nil)
    79  	}()
    80  }
    81  
    82  func TestValueConcurrent(t *testing.T) {
    83  	tests := [][]any{
    84  		{uint16(0), ^uint16(0), uint16(1 + 2<<8), uint16(3 + 4<<8)},
    85  		{uint32(0), ^uint32(0), uint32(1 + 2<<16), uint32(3 + 4<<16)},
    86  		{uint64(0), ^uint64(0), uint64(1 + 2<<32), uint64(3 + 4<<32)},
    87  		{complex(0, 0), complex(1, 2), complex(3, 4), complex(5, 6)},
    88  	}
    89  	p := 4 * runtime.GOMAXPROCS(0)
    90  	N := int(1e5)
    91  	if testing.Short() {
    92  		p /= 2
    93  		N = 1e3
    94  	}
    95  	for _, test := range tests {
    96  		var v Value
    97  		done := make(chan bool, p)
    98  		for i := 0; i < p; i++ {
    99  			go func() {
   100  				r := rand.New(rand.NewSource(rand.Int63()))
   101  				expected := true
   102  			loop:
   103  				for j := 0; j < N; j++ {
   104  					x := test[r.Intn(len(test))]
   105  					v.Store(x)
   106  					x = v.Load()
   107  					for _, x1 := range test {
   108  						if x == x1 {
   109  							continue loop
   110  						}
   111  					}
   112  					t.Logf("loaded unexpected value %+v, want %+v", x, test)
   113  					expected = false
   114  					break
   115  				}
   116  				done <- expected
   117  			}()
   118  		}
   119  		for i := 0; i < p; i++ {
   120  			if !<-done {
   121  				t.FailNow()
   122  			}
   123  		}
   124  	}
   125  }
   126  
   127  func BenchmarkValueRead(b *testing.B) {
   128  	var v Value
   129  	v.Store(new(int))
   130  	b.RunParallel(func(pb *testing.PB) {
   131  		for pb.Next() {
   132  			x := v.Load().(*int)
   133  			if *x != 0 {
   134  				b.Fatalf("wrong value: got %v, want 0", *x)
   135  			}
   136  		}
   137  	})
   138  }
   139  
   140  var Value_SwapTests = []struct {
   141  	init any
   142  	new  any
   143  	want any
   144  	err  any
   145  }{
   146  	{init: nil, new: nil, err: "sync/atomic: swap of nil value into Value"},
   147  	{init: nil, new: true, want: nil, err: nil},
   148  	{init: true, new: "", err: "sync/atomic: swap of inconsistently typed value into Value"},
   149  	{init: true, new: false, want: true, err: nil},
   150  }
   151  
   152  func TestValue_Swap(t *testing.T) {
   153  	for i, tt := range Value_SwapTests {
   154  		t.Run(strconv.Itoa(i), func(t *testing.T) {
   155  			var v Value
   156  			if tt.init != nil {
   157  				v.Store(tt.init)
   158  			}
   159  			defer func() {
   160  				err := recover()
   161  				switch {
   162  				case tt.err == nil && err != nil:
   163  					t.Errorf("should not panic, got %v", err)
   164  				case tt.err != nil && err == nil:
   165  					t.Errorf("should panic %v, got <nil>", tt.err)
   166  				}
   167  			}()
   168  			if got := v.Swap(tt.new); got != tt.want {
   169  				t.Errorf("got %v, want %v", got, tt.want)
   170  			}
   171  			if got := v.Load(); got != tt.new {
   172  				t.Errorf("got %v, want %v", got, tt.new)
   173  			}
   174  		})
   175  	}
   176  }
   177  
   178  func TestValueSwapConcurrent(t *testing.T) {
   179  	var v Value
   180  	var count uint64
   181  	var g sync.WaitGroup
   182  	var m, n uint64 = 10000, 10000
   183  	if testing.Short() {
   184  		m = 1000
   185  		n = 1000
   186  	}
   187  	for i := uint64(0); i < m*n; i += n {
   188  		i := i
   189  		g.Add(1)
   190  		go func() {
   191  			var c uint64
   192  			for new := i; new < i+n; new++ {
   193  				if old := v.Swap(new); old != nil {
   194  					c += old.(uint64)
   195  				}
   196  			}
   197  			atomic.AddUint64(&count, c)
   198  			g.Done()
   199  		}()
   200  	}
   201  	g.Wait()
   202  	if want, got := (m*n-1)*(m*n)/2, count+v.Load().(uint64); got != want {
   203  		t.Errorf("sum from 0 to %d was %d, want %v", m*n-1, got, want)
   204  	}
   205  }
   206  
   207  var heapA, heapB = struct{ uint }{0}, struct{ uint }{0}
   208  
   209  var Value_CompareAndSwapTests = []struct {
   210  	init any
   211  	new  any
   212  	old  any
   213  	want bool
   214  	err  any
   215  }{
   216  	{init: nil, new: nil, old: nil, err: "sync/atomic: compare and swap of nil value into Value"},
   217  	{init: nil, new: true, old: "", err: "sync/atomic: compare and swap of inconsistently typed values into Value"},
   218  	{init: nil, new: true, old: true, want: false, err: nil},
   219  	{init: nil, new: true, old: nil, want: true, err: nil},
   220  	{init: true, new: "", err: "sync/atomic: compare and swap of inconsistently typed value into Value"},
   221  	{init: true, new: true, old: false, want: false, err: nil},
   222  	{init: true, new: true, old: true, want: true, err: nil},
   223  	{init: heapA, new: struct{ uint }{1}, old: heapB, want: true, err: nil},
   224  }
   225  
   226  func TestValue_CompareAndSwap(t *testing.T) {
   227  	for i, tt := range Value_CompareAndSwapTests {
   228  		t.Run(strconv.Itoa(i), func(t *testing.T) {
   229  			var v Value
   230  			if tt.init != nil {
   231  				v.Store(tt.init)
   232  			}
   233  			defer func() {
   234  				err := recover()
   235  				switch {
   236  				case tt.err == nil && err != nil:
   237  					t.Errorf("got %v, wanted no panic", err)
   238  				case tt.err != nil && err == nil:
   239  					t.Errorf("did not panic, want %v", tt.err)
   240  				}
   241  			}()
   242  			if got := v.CompareAndSwap(tt.old, tt.new); got != tt.want {
   243  				t.Errorf("got %v, want %v", got, tt.want)
   244  			}
   245  		})
   246  	}
   247  }
   248  
   249  func TestValueCompareAndSwapConcurrent(t *testing.T) {
   250  	var v Value
   251  	var w sync.WaitGroup
   252  	v.Store(0)
   253  	m, n := 1000, 100
   254  	if testing.Short() {
   255  		m = 100
   256  		n = 100
   257  	}
   258  	for i := 0; i < m; i++ {
   259  		i := i
   260  		w.Add(1)
   261  		go func() {
   262  			for j := i; j < m*n; runtime.Gosched() {
   263  				if v.CompareAndSwap(j, j+1) {
   264  					j += m
   265  				}
   266  			}
   267  			w.Done()
   268  		}()
   269  	}
   270  	w.Wait()
   271  	if stop := v.Load().(int); stop != m*n {
   272  		t.Errorf("did not get to %v, stopped at %v", m*n, stop)
   273  	}
   274  }
   275  

View as plain text