Source file test/typeparam/sets.go

     1  // run
     2  
     3  // Copyright 2021 The Go Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  
     7  package main
     8  
     9  import (
    10  	"fmt"
    11  	"sort"
    12  )
    13  
    14  // _Equal reports whether two slices are equal: the same length and all
    15  // elements equal. All floating point NaNs are considered equal.
    16  func _SliceEqual[Elem comparable](s1, s2 []Elem) bool {
    17  	if len(s1) != len(s2) {
    18  		return false
    19  	}
    20  	for i, v1 := range s1 {
    21  		v2 := s2[i]
    22  		if v1 != v2 {
    23  			isNaN := func(f Elem) bool { return f != f }
    24  			if !isNaN(v1) || !isNaN(v2) {
    25  				return false
    26  			}
    27  		}
    28  	}
    29  	return true
    30  }
    31  
    32  // A _Set is a set of elements of some type.
    33  type _Set[Elem comparable] struct {
    34  	m map[Elem]struct{}
    35  }
    36  
    37  // _Make makes a new set.
    38  func _Make[Elem comparable]() _Set[Elem] {
    39  	return _Set[Elem]{m: make(map[Elem]struct{})}
    40  }
    41  
    42  // Add adds an element to a set.
    43  func (s _Set[Elem]) Add(v Elem) {
    44  	s.m[v] = struct{}{}
    45  }
    46  
    47  // Delete removes an element from a set. If the element is not present
    48  // in the set, this does nothing.
    49  func (s _Set[Elem]) Delete(v Elem) {
    50  	delete(s.m, v)
    51  }
    52  
    53  // Contains reports whether v is in the set.
    54  func (s _Set[Elem]) Contains(v Elem) bool {
    55  	_, ok := s.m[v]
    56  	return ok
    57  }
    58  
    59  // Len returns the number of elements in the set.
    60  func (s _Set[Elem]) Len() int {
    61  	return len(s.m)
    62  }
    63  
    64  // Values returns the values in the set.
    65  // The values will be in an indeterminate order.
    66  func (s _Set[Elem]) Values() []Elem {
    67  	r := make([]Elem, 0, len(s.m))
    68  	for v := range s.m {
    69  		r = append(r, v)
    70  	}
    71  	return r
    72  }
    73  
    74  // _Equal reports whether two sets contain the same elements.
    75  func _Equal[Elem comparable](s1, s2 _Set[Elem]) bool {
    76  	if len(s1.m) != len(s2.m) {
    77  		return false
    78  	}
    79  	for v1 := range s1.m {
    80  		if !s2.Contains(v1) {
    81  			return false
    82  		}
    83  	}
    84  	return true
    85  }
    86  
    87  // Copy returns a copy of s.
    88  func (s _Set[Elem]) Copy() _Set[Elem] {
    89  	r := _Set[Elem]{m: make(map[Elem]struct{}, len(s.m))}
    90  	for v := range s.m {
    91  		r.m[v] = struct{}{}
    92  	}
    93  	return r
    94  }
    95  
    96  // AddSet adds all the elements of s2 to s.
    97  func (s _Set[Elem]) AddSet(s2 _Set[Elem]) {
    98  	for v := range s2.m {
    99  		s.m[v] = struct{}{}
   100  	}
   101  }
   102  
   103  // SubSet removes all elements in s2 from s.
   104  // Values in s2 that are not in s are ignored.
   105  func (s _Set[Elem]) SubSet(s2 _Set[Elem]) {
   106  	for v := range s2.m {
   107  		delete(s.m, v)
   108  	}
   109  }
   110  
   111  // Intersect removes all elements from s that are not present in s2.
   112  // Values in s2 that are not in s are ignored.
   113  func (s _Set[Elem]) Intersect(s2 _Set[Elem]) {
   114  	for v := range s.m {
   115  		if !s2.Contains(v) {
   116  			delete(s.m, v)
   117  		}
   118  	}
   119  }
   120  
   121  // Iterate calls f on every element in the set.
   122  func (s _Set[Elem]) Iterate(f func(Elem)) {
   123  	for v := range s.m {
   124  		f(v)
   125  	}
   126  }
   127  
   128  // Filter deletes any elements from s for which f returns false.
   129  func (s _Set[Elem]) Filter(f func(Elem) bool) {
   130  	for v := range s.m {
   131  		if !f(v) {
   132  			delete(s.m, v)
   133  		}
   134  	}
   135  }
   136  
   137  func TestSet() {
   138  	s1 := _Make[int]()
   139  	if got := s1.Len(); got != 0 {
   140  		panic(fmt.Sprintf("Len of empty set = %d, want 0", got))
   141  	}
   142  	s1.Add(1)
   143  	s1.Add(1)
   144  	s1.Add(1)
   145  	if got := s1.Len(); got != 1 {
   146  		panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got))
   147  	}
   148  	s1.Add(2)
   149  	s1.Add(3)
   150  	s1.Add(4)
   151  	if got := s1.Len(); got != 4 {
   152  		panic(fmt.Sprintf("(%v).Len() == %d, want 4", s1, got))
   153  	}
   154  	if !s1.Contains(1) {
   155  		panic(fmt.Sprintf("(%v).Contains(1) == false, want true", s1))
   156  	}
   157  	if s1.Contains(5) {
   158  		panic(fmt.Sprintf("(%v).Contains(5) == true, want false", s1))
   159  	}
   160  	vals := s1.Values()
   161  	sort.Ints(vals)
   162  	w1 := []int{1, 2, 3, 4}
   163  	if !_SliceEqual(vals, w1) {
   164  		panic(fmt.Sprintf("(%v).Values() == %v, want %v", s1, vals, w1))
   165  	}
   166  }
   167  
   168  func TestEqual() {
   169  	s1 := _Make[string]()
   170  	s2 := _Make[string]()
   171  	if !_Equal(s1, s2) {
   172  		panic(fmt.Sprintf("_Equal(%v, %v) = false, want true", s1, s2))
   173  	}
   174  	s1.Add("hello")
   175  	s1.Add("world")
   176  	if got := s1.Len(); got != 2 {
   177  		panic(fmt.Sprintf("(%v).Len() == %d, want 2", s1, got))
   178  	}
   179  	if _Equal(s1, s2) {
   180  		panic(fmt.Sprintf("_Equal(%v, %v) = true, want false", s1, s2))
   181  	}
   182  }
   183  
   184  func TestCopy() {
   185  	s1 := _Make[float64]()
   186  	s1.Add(0)
   187  	s2 := s1.Copy()
   188  	if !_Equal(s1, s2) {
   189  		panic(fmt.Sprintf("_Equal(%v, %v) = false, want true", s1, s2))
   190  	}
   191  	s1.Add(1)
   192  	if _Equal(s1, s2) {
   193  		panic(fmt.Sprintf("_Equal(%v, %v) = true, want false", s1, s2))
   194  	}
   195  }
   196  
   197  func TestAddSet() {
   198  	s1 := _Make[int]()
   199  	s1.Add(1)
   200  	s1.Add(2)
   201  	s2 := _Make[int]()
   202  	s2.Add(2)
   203  	s2.Add(3)
   204  	s1.AddSet(s2)
   205  	if got := s1.Len(); got != 3 {
   206  		panic(fmt.Sprintf("(%v).Len() == %d, want 3", s1, got))
   207  	}
   208  	s2.Add(1)
   209  	if !_Equal(s1, s2) {
   210  		panic(fmt.Sprintf("_Equal(%v, %v) = false, want true", s1, s2))
   211  	}
   212  }
   213  
   214  func TestSubSet() {
   215  	s1 := _Make[int]()
   216  	s1.Add(1)
   217  	s1.Add(2)
   218  	s2 := _Make[int]()
   219  	s2.Add(2)
   220  	s2.Add(3)
   221  	s1.SubSet(s2)
   222  	if got := s1.Len(); got != 1 {
   223  		panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got))
   224  	}
   225  	if vals, want := s1.Values(), []int{1}; !_SliceEqual(vals, want) {
   226  		panic(fmt.Sprintf("after SubSet got %v, want %v", vals, want))
   227  	}
   228  }
   229  
   230  func TestIntersect() {
   231  	s1 := _Make[int]()
   232  	s1.Add(1)
   233  	s1.Add(2)
   234  	s2 := _Make[int]()
   235  	s2.Add(2)
   236  	s2.Add(3)
   237  	s1.Intersect(s2)
   238  	if got := s1.Len(); got != 1 {
   239  		panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got))
   240  	}
   241  	if vals, want := s1.Values(), []int{2}; !_SliceEqual(vals, want) {
   242  		panic(fmt.Sprintf("after Intersect got %v, want %v", vals, want))
   243  	}
   244  }
   245  
   246  func TestIterate() {
   247  	s1 := _Make[int]()
   248  	s1.Add(1)
   249  	s1.Add(2)
   250  	s1.Add(3)
   251  	s1.Add(4)
   252  	tot := 0
   253  	s1.Iterate(func(i int) { tot += i })
   254  	if tot != 10 {
   255  		panic(fmt.Sprintf("total of %v == %d, want 10", s1, tot))
   256  	}
   257  }
   258  
   259  func TestFilter() {
   260  	s1 := _Make[int]()
   261  	s1.Add(1)
   262  	s1.Add(2)
   263  	s1.Add(3)
   264  	s1.Filter(func(v int) bool { return v%2 == 0 })
   265  	if vals, want := s1.Values(), []int{2}; !_SliceEqual(vals, want) {
   266  		panic(fmt.Sprintf("after Filter got %v, want %v", vals, want))
   267  	}
   268  
   269  }
   270  
   271  func main() {
   272  	TestSet()
   273  	TestEqual()
   274  	TestCopy()
   275  	TestAddSet()
   276  	TestSubSet()
   277  	TestIntersect()
   278  	TestIterate()
   279  	TestFilter()
   280  }
   281  

View as plain text