Source file src/database/sql/fakedb_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"reflect"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  )
    21  
// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL.  The syntax is as
// follows:
//
//	WIPE
//	CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//	  where types are: "string", [u]int{8,16,32,64}, "bool"
//	INSERT|<tablename>|col=val,col2=val2,col3=?
//	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Any of these can be preceded by WAIT|<duration>|, to make
// fakeConn.PrepareContext and fakeStmt.ExecContext wait for the
// specified duration.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex         // guards openCount, closeCount and dbs
	openCount  int                // conn opens
	closeCount int                // conn closes
	waitCh     chan struct{}      // if non-nil, the next Open blocks receiving from it; read and cleared by Open without holding mu
	waitingCh  chan struct{}      // Open sends on it just before blocking on waitCh
	dbs        map[string]*fakeDB // fake databases by name; populated lazily by getDB
}
    54  
    55  type fakeConnector struct {
    56  	name string
    57  
    58  	waiter func(context.Context)
    59  	closed bool
    60  }
    61  
    62  func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
    63  	conn, err := fdriver.Open(c.name)
    64  	conn.(*fakeConn).waiter = c.waiter
    65  	return conn, err
    66  }
    67  
    68  func (c *fakeConnector) Driver() driver.Driver {
    69  	return fdriver
    70  }
    71  
    72  func (c *fakeConnector) Close() error {
    73  	if c.closed {
    74  		return errors.New("fakedb: connector is closed")
    75  	}
    76  	c.closed = true
    77  	return nil
    78  }
    79  
    80  type fakeDriverCtx struct {
    81  	fakeDriver
    82  }
    83  
    84  var _ driver.DriverContext = &fakeDriverCtx{}
    85  
    86  func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
    87  	return &fakeConnector{name: name}, nil
    88  }
    89  
// fakeDB is one named in-memory database; fakeDriver.getDB hands the same
// instance to every connection opened with that name.
type fakeDB struct {
	name string

	mu       sync.Mutex        // guards tables
	tables   map[string]*table // nil until createTable runs, and again after wipe
	badConn  bool              // flipped by fakeConn.isBad for conns opened with the "badConn" option
	allowAny bool              // when true, checkSubsetTypes accepts argument values of any type
}
    98  
    99  type fakeError struct {
   100  	Message string
   101  	Wrapped error
   102  }
   103  
   104  func (err fakeError) Error() string {
   105  	return err.Message
   106  }
   107  
   108  func (err fakeError) Unwrap() error {
   109  	return err.Wrapped
   110  }
   111  
   112  type table struct {
   113  	mu      sync.Mutex
   114  	colname []string
   115  	coltype []string
   116  	rows    []*row
   117  }
   118  
   119  func (t *table) columnIndex(name string) int {
   120  	for n, nname := range t.colname {
   121  		if name == nname {
   122  			return n
   123  		}
   124  	}
   125  	return -1
   126  }
   127  
// row holds the values of one row. Stored table rows have one value per
// table column (len(colname)); rows built for query results hold only the
// selected columns, in selection order.
type row struct {
	cols []any // must be same size as its table colname + coltype
}

// memToucher is implemented by values whose memory can be touched so the
// race detector observes accesses; fakeConn implements it.
type memToucher interface {
	// touchMem reads & writes some memory, to help find data races.
	touchMem()
}
   136  
// fakeConn is a connection on a fakeDB, created by fakeDriver.Open and
// returned as a driver.Conn; it also implements driver.Validator.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; set to nil by a successful Close

	// currTx is the transaction opened by Begin, or nil. Commit, Rollback
	// and ResetSession clear it.
	currTx *fakeTx

	// Every operation writes to line to enable the race detector
	// check for data races.
	line int64

	// Stats for tests:
	// mu guards stmtsMade and stmtsClosed (updated via incrStat).
	// numPrepare is incremented by PrepareContext without holding mu.
	mu          sync.Mutex
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad connection tests; see isBad()
	bad       bool // set by Open for a dsn with the "badConn" option
	stickyBad bool // when true, the conn always reports itself as bad

	// tests that use Conn should set this to true. prepareInsert also sets it
	// when a "table" column value opens a cursor.
	skipDirtySession bool // tests that use Conn should set this to true.

	// dirtySession tests ResetSession, true if a query has executed
	// until ResetSession is called.
	dirtySession bool

	// The waiter is called by PrepareContext before each statement is
	// prepared; fakeConnector.Connect installs it. May be used in place of
	// the "WAIT" directive.
	waiter func(context.Context)
}
   166  
   167  func (c *fakeConn) touchMem() {
   168  	c.line++
   169  }
   170  
   171  func (c *fakeConn) incrStat(v *int) {
   172  	c.mu.Lock()
   173  	*v++
   174  	c.mu.Unlock()
   175  }
   176  
// fakeTx is the driver.Tx returned by fakeConn.Begin.
type fakeTx struct {
	c *fakeConn
}

// boundCol is a SELECT where-clause column bound to a placeholder.
type boundCol struct {
	Column      string // where-clause column name
	Placeholder string // "?" for a positional placeholder, or "?name" for a named one
	Ordinal     int    // 1-based position of this placeholder within the statement
}
   186  
// fakeStmt is one parsed statement; fakeConn.PrepareContext builds it from
// the query text.
type fakeStmt struct {
	memToucher // set to the owning conn by PrepareContext
	c *fakeConn
	q string // just for debugging

	cmd   string        // WIPE, CREATE, INSERT, NOSERT or SELECT
	table string        // target table name
	panic string        // method name from a PANIC|<method>| prefix
	wait  time.Duration // duration from a WAIT|<duration>| prefix

	// next is the following statement of a ";"-separated query. It is used
	// for returning multiple results, and Close closes the whole chain.
	next *fakeStmt // used for returning multiple results.

	closed bool // set by Close

	colName      []string // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string // used by CREATE
	colValue     []any    // used by INSERT: pre-bound values already converted ([]byte, int64 or a cursor), or the "?"/"?name" placeholder strings
	placeholders int      // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT: one per placeholder, in order (see ColumnConverter)
}
   210  
// fdriver is the shared *fakeDriver instance used throughout these tests.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
   216  
   217  func contains(list []string, y string) bool {
   218  	for _, x := range list {
   219  		if x == y {
   220  			return true
   221  		}
   222  	}
   223  	return false
   224  }
   225  
   226  type Dummy struct {
   227  	driver.Driver
   228  }
   229  
   230  func TestDrivers(t *testing.T) {
   231  	unregisterAllDrivers()
   232  	Register("test", fdriver)
   233  	Register("invalid", Dummy{})
   234  	all := Drivers()
   235  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   236  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   237  	}
   238  }
   239  
   240  // hook to simulate connection failures
   241  var hookOpenErr struct {
   242  	sync.Mutex
   243  	fn func() error
   244  }
   245  
   246  func setHookOpenErr(fn func() error) {
   247  	hookOpenErr.Lock()
   248  	defer hookOpenErr.Unlock()
   249  	hookOpenErr.fn = fn
   250  }
   251  
   252  // Supports dsn forms:
   253  //
   254  //	<dbname>
   255  //	<dbname>;<opts>  (only currently supported option is `badConn`,
   256  //	                  which causes driver.ErrBadConn to be returned on
   257  //	                  every other conn.Begin())
   258  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
   259  	hookOpenErr.Lock()
   260  	fn := hookOpenErr.fn
   261  	hookOpenErr.Unlock()
   262  	if fn != nil {
   263  		if err := fn(); err != nil {
   264  			return nil, err
   265  		}
   266  	}
   267  	parts := strings.Split(dsn, ";")
   268  	if len(parts) < 1 {
   269  		return nil, errors.New("fakedb: no database name")
   270  	}
   271  	name := parts[0]
   272  
   273  	db := d.getDB(name)
   274  
   275  	d.mu.Lock()
   276  	d.openCount++
   277  	d.mu.Unlock()
   278  	conn := &fakeConn{db: db}
   279  
   280  	if len(parts) >= 2 && parts[1] == "badConn" {
   281  		conn.bad = true
   282  	}
   283  	if d.waitCh != nil {
   284  		d.waitingCh <- struct{}{}
   285  		<-d.waitCh
   286  		d.waitCh = nil
   287  		d.waitingCh = nil
   288  	}
   289  	return conn, nil
   290  }
   291  
   292  func (d *fakeDriver) getDB(name string) *fakeDB {
   293  	d.mu.Lock()
   294  	defer d.mu.Unlock()
   295  	if d.dbs == nil {
   296  		d.dbs = make(map[string]*fakeDB)
   297  	}
   298  	db, ok := d.dbs[name]
   299  	if !ok {
   300  		db = &fakeDB{name: name}
   301  		d.dbs[name] = db
   302  	}
   303  	return db
   304  }
   305  
   306  func (db *fakeDB) wipe() {
   307  	db.mu.Lock()
   308  	defer db.mu.Unlock()
   309  	db.tables = nil
   310  }
   311  
   312  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   313  	db.mu.Lock()
   314  	defer db.mu.Unlock()
   315  	if db.tables == nil {
   316  		db.tables = make(map[string]*table)
   317  	}
   318  	if _, exist := db.tables[name]; exist {
   319  		return fmt.Errorf("fakedb: table %q already exists", name)
   320  	}
   321  	if len(columnNames) != len(columnTypes) {
   322  		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
   323  			name, len(columnNames), len(columnTypes))
   324  	}
   325  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   326  	return nil
   327  }
   328  
   329  // must be called with db.mu lock held
   330  func (db *fakeDB) table(table string) (*table, bool) {
   331  	if db.tables == nil {
   332  		return nil, false
   333  	}
   334  	t, ok := db.tables[table]
   335  	return t, ok
   336  }
   337  
   338  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   339  	db.mu.Lock()
   340  	defer db.mu.Unlock()
   341  	t, ok := db.table(table)
   342  	if !ok {
   343  		return
   344  	}
   345  	for n, cname := range t.colname {
   346  		if cname == column {
   347  			return t.coltype[n], true
   348  		}
   349  	}
   350  	return "", false
   351  }
   352  
   353  func (c *fakeConn) isBad() bool {
   354  	if c.stickyBad {
   355  		return true
   356  	} else if c.bad {
   357  		if c.db == nil {
   358  			return false
   359  		}
   360  		// alternate between bad conn and not bad conn
   361  		c.db.badConn = !c.db.badConn
   362  		return c.db.badConn
   363  	} else {
   364  		return false
   365  	}
   366  }
   367  
   368  func (c *fakeConn) isDirtyAndMark() bool {
   369  	if c.skipDirtySession {
   370  		return false
   371  	}
   372  	if c.currTx != nil {
   373  		c.dirtySession = true
   374  		return false
   375  	}
   376  	if c.dirtySession {
   377  		return true
   378  	}
   379  	c.dirtySession = true
   380  	return false
   381  }
   382  
   383  func (c *fakeConn) Begin() (driver.Tx, error) {
   384  	if c.isBad() {
   385  		return nil, fakeError{Wrapped: driver.ErrBadConn}
   386  	}
   387  	if c.currTx != nil {
   388  		return nil, errors.New("fakedb: already in a transaction")
   389  	}
   390  	c.touchMem()
   391  	c.currTx = &fakeTx{c: c}
   392  	return c.currTx, nil
   393  }
   394  
   395  var hookPostCloseConn struct {
   396  	sync.Mutex
   397  	fn func(*fakeConn, error)
   398  }
   399  
   400  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   401  	hookPostCloseConn.Lock()
   402  	defer hookPostCloseConn.Unlock()
   403  	hookPostCloseConn.fn = fn
   404  }
   405  
   406  var testStrictClose *testing.T
   407  
   408  // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
   409  // fails to close. If nil, the check is disabled.
   410  func setStrictFakeConnClose(t *testing.T) {
   411  	testStrictClose = t
   412  }
   413  
   414  func (c *fakeConn) ResetSession(ctx context.Context) error {
   415  	c.dirtySession = false
   416  	c.currTx = nil
   417  	if c.isBad() {
   418  		return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
   419  	}
   420  	return nil
   421  }
   422  
   423  var _ driver.Validator = (*fakeConn)(nil)
   424  
   425  func (c *fakeConn) IsValid() bool {
   426  	return !c.isBad()
   427  }
   428  
   429  func (c *fakeConn) Close() (err error) {
   430  	drv := fdriver.(*fakeDriver)
   431  	defer func() {
   432  		if err != nil && testStrictClose != nil {
   433  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
   434  		}
   435  		hookPostCloseConn.Lock()
   436  		fn := hookPostCloseConn.fn
   437  		hookPostCloseConn.Unlock()
   438  		if fn != nil {
   439  			fn(c, err)
   440  		}
   441  		if err == nil {
   442  			drv.mu.Lock()
   443  			drv.closeCount++
   444  			drv.mu.Unlock()
   445  		}
   446  	}()
   447  	c.touchMem()
   448  	if c.currTx != nil {
   449  		return errors.New("fakedb: can't close fakeConn; in a Transaction")
   450  	}
   451  	if c.db == nil {
   452  		return errors.New("fakedb: can't close fakeConn; already closed")
   453  	}
   454  	if c.stmtsMade > c.stmtsClosed {
   455  		return errors.New("fakedb: can't close; dangling statement(s)")
   456  	}
   457  	c.db = nil
   458  	return nil
   459  }
   460  
   461  func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
   462  	for _, arg := range args {
   463  		switch arg.Value.(type) {
   464  		case int64, float64, bool, nil, []byte, string, time.Time:
   465  		default:
   466  			if !allowAny {
   467  				return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   468  			}
   469  		}
   470  	}
   471  	return nil
   472  }
   473  
   474  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   475  	// Ensure that ExecContext is called if available.
   476  	panic("ExecContext was not called.")
   477  }
   478  
   479  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   480  	// This is an optional interface, but it's implemented here
   481  	// just to check that all the args are of the proper types.
   482  	// ErrSkip is returned so the caller acts as if we didn't
   483  	// implement this at all.
   484  	err := checkSubsetTypes(c.db.allowAny, args)
   485  	if err != nil {
   486  		return nil, err
   487  	}
   488  	return nil, driver.ErrSkip
   489  }
   490  
   491  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   492  	// Ensure that ExecContext is called if available.
   493  	panic("QueryContext was not called.")
   494  }
   495  
   496  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   497  	// This is an optional interface, but it's implemented here
   498  	// just to check that all the args are of the proper types.
   499  	// ErrSkip is returned so the caller acts as if we didn't
   500  	// implement this at all.
   501  	err := checkSubsetTypes(c.db.allowAny, args)
   502  	if err != nil {
   503  		return nil, err
   504  	}
   505  	return nil, driver.ErrSkip
   506  }
   507  
   508  func errf(msg string, args ...any) error {
   509  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   510  }
   511  
   512  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   513  // (note that where columns must always contain ? marks,
   514  // just a limitation for fakedb)
   515  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   516  	if len(parts) != 3 {
   517  		stmt.Close()
   518  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   519  	}
   520  	stmt.table = parts[0]
   521  
   522  	stmt.colName = strings.Split(parts[1], ",")
   523  	for n, colspec := range strings.Split(parts[2], ",") {
   524  		if colspec == "" {
   525  			continue
   526  		}
   527  		nameVal := strings.Split(colspec, "=")
   528  		if len(nameVal) != 2 {
   529  			stmt.Close()
   530  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   531  		}
   532  		column, value := nameVal[0], nameVal[1]
   533  		_, ok := c.db.columnType(stmt.table, column)
   534  		if !ok {
   535  			stmt.Close()
   536  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   537  		}
   538  		if !strings.HasPrefix(value, "?") {
   539  			stmt.Close()
   540  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   541  				stmt.table, column)
   542  		}
   543  		stmt.placeholders++
   544  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   545  	}
   546  	return stmt, nil
   547  }
   548  
   549  // parts are table|col=type,col2=type2
   550  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   551  	if len(parts) != 2 {
   552  		stmt.Close()
   553  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   554  	}
   555  	stmt.table = parts[0]
   556  	for n, colspec := range strings.Split(parts[1], ",") {
   557  		nameType := strings.Split(colspec, "=")
   558  		if len(nameType) != 2 {
   559  			stmt.Close()
   560  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   561  		}
   562  		stmt.colName = append(stmt.colName, nameType[0])
   563  		stmt.colType = append(stmt.colType, nameType[1])
   564  	}
   565  	return stmt, nil
   566  }
   567  
   568  // parts are table|col=?,col2=val
   569  func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   570  	if len(parts) != 2 {
   571  		stmt.Close()
   572  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   573  	}
   574  	stmt.table = parts[0]
   575  	for n, colspec := range strings.Split(parts[1], ",") {
   576  		nameVal := strings.Split(colspec, "=")
   577  		if len(nameVal) != 2 {
   578  			stmt.Close()
   579  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   580  		}
   581  		column, value := nameVal[0], nameVal[1]
   582  		ctype, ok := c.db.columnType(stmt.table, column)
   583  		if !ok {
   584  			stmt.Close()
   585  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   586  		}
   587  		stmt.colName = append(stmt.colName, column)
   588  
   589  		if !strings.HasPrefix(value, "?") {
   590  			var subsetVal any
   591  			// Convert to driver subset type
   592  			switch ctype {
   593  			case "string":
   594  				subsetVal = []byte(value)
   595  			case "blob":
   596  				subsetVal = []byte(value)
   597  			case "int32":
   598  				i, err := strconv.Atoi(value)
   599  				if err != nil {
   600  					stmt.Close()
   601  					return nil, errf("invalid conversion to int32 from %q", value)
   602  				}
   603  				subsetVal = int64(i) // int64 is a subset type, but not int32
   604  			case "table": // For testing cursor reads.
   605  				c.skipDirtySession = true
   606  				vparts := strings.Split(value, "!")
   607  
   608  				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
   609  				if err != nil {
   610  					return nil, err
   611  				}
   612  				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
   613  				substmt.Close()
   614  				if err != nil {
   615  					return nil, err
   616  				}
   617  				subsetVal = cursor
   618  			default:
   619  				stmt.Close()
   620  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   621  			}
   622  			stmt.colValue = append(stmt.colValue, subsetVal)
   623  		} else {
   624  			stmt.placeholders++
   625  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   626  			stmt.colValue = append(stmt.colValue, value)
   627  		}
   628  	}
   629  	return stmt, nil
   630  }
   631  
   632  // hook to simulate broken connections
   633  var hookPrepareBadConn func() bool
   634  
   635  func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
   636  	panic("use PrepareContext")
   637  }
   638  
   639  func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   640  	c.numPrepare++
   641  	if c.db == nil {
   642  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   643  	}
   644  
   645  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   646  		return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
   647  	}
   648  
   649  	c.touchMem()
   650  	var firstStmt, prev *fakeStmt
   651  	for _, query := range strings.Split(query, ";") {
   652  		parts := strings.Split(query, "|")
   653  		if len(parts) < 1 {
   654  			return nil, errf("empty query")
   655  		}
   656  		stmt := &fakeStmt{q: query, c: c, memToucher: c}
   657  		if firstStmt == nil {
   658  			firstStmt = stmt
   659  		}
   660  		if len(parts) >= 3 {
   661  			switch parts[0] {
   662  			case "PANIC":
   663  				stmt.panic = parts[1]
   664  				parts = parts[2:]
   665  			case "WAIT":
   666  				wait, err := time.ParseDuration(parts[1])
   667  				if err != nil {
   668  					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
   669  				}
   670  				parts = parts[2:]
   671  				stmt.wait = wait
   672  			}
   673  		}
   674  		cmd := parts[0]
   675  		stmt.cmd = cmd
   676  		parts = parts[1:]
   677  
   678  		if c.waiter != nil {
   679  			c.waiter(ctx)
   680  			if err := ctx.Err(); err != nil {
   681  				return nil, err
   682  			}
   683  		}
   684  
   685  		if stmt.wait > 0 {
   686  			wait := time.NewTimer(stmt.wait)
   687  			select {
   688  			case <-wait.C:
   689  			case <-ctx.Done():
   690  				wait.Stop()
   691  				return nil, ctx.Err()
   692  			}
   693  		}
   694  
   695  		c.incrStat(&c.stmtsMade)
   696  		var err error
   697  		switch cmd {
   698  		case "WIPE":
   699  			// Nothing
   700  		case "SELECT":
   701  			stmt, err = c.prepareSelect(stmt, parts)
   702  		case "CREATE":
   703  			stmt, err = c.prepareCreate(stmt, parts)
   704  		case "INSERT":
   705  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   706  		case "NOSERT":
   707  			// Do all the prep-work like for an INSERT but don't actually insert the row.
   708  			// Used for some of the concurrent tests.
   709  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   710  		default:
   711  			stmt.Close()
   712  			return nil, errf("unsupported command type %q", cmd)
   713  		}
   714  		if err != nil {
   715  			return nil, err
   716  		}
   717  		if prev != nil {
   718  			prev.next = stmt
   719  		}
   720  		prev = stmt
   721  	}
   722  	return firstStmt, nil
   723  }
   724  
   725  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   726  	if s.panic == "ColumnConverter" {
   727  		panic(s.panic)
   728  	}
   729  	if len(s.placeholderConverter) == 0 {
   730  		return driver.DefaultParameterConverter
   731  	}
   732  	return s.placeholderConverter[idx]
   733  }
   734  
   735  func (s *fakeStmt) Close() error {
   736  	if s.panic == "Close" {
   737  		panic(s.panic)
   738  	}
   739  	if s.c == nil {
   740  		panic("nil conn in fakeStmt.Close")
   741  	}
   742  	if s.c.db == nil {
   743  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   744  	}
   745  	s.touchMem()
   746  	if !s.closed {
   747  		s.c.incrStat(&s.c.stmtsClosed)
   748  		s.closed = true
   749  	}
   750  	if s.next != nil {
   751  		s.next.Close()
   752  	}
   753  	return nil
   754  }
   755  
   756  var errClosed = errors.New("fakedb: statement has been closed")
   757  
   758  // hook to simulate broken connections
   759  var hookExecBadConn func() bool
   760  
   761  func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
   762  	panic("Using ExecContext")
   763  }
   764  
   765  var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
   766  
   767  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   768  	if s.panic == "Exec" {
   769  		panic(s.panic)
   770  	}
   771  	if s.closed {
   772  		return nil, errClosed
   773  	}
   774  
   775  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   776  		return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
   777  	}
   778  	if s.c.isDirtyAndMark() {
   779  		return nil, errFakeConnSessionDirty
   780  	}
   781  
   782  	err := checkSubsetTypes(s.c.db.allowAny, args)
   783  	if err != nil {
   784  		return nil, err
   785  	}
   786  	s.touchMem()
   787  
   788  	if s.wait > 0 {
   789  		time.Sleep(s.wait)
   790  	}
   791  
   792  	select {
   793  	default:
   794  	case <-ctx.Done():
   795  		return nil, ctx.Err()
   796  	}
   797  
   798  	db := s.c.db
   799  	switch s.cmd {
   800  	case "WIPE":
   801  		db.wipe()
   802  		return driver.ResultNoRows, nil
   803  	case "CREATE":
   804  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   805  			return nil, err
   806  		}
   807  		return driver.ResultNoRows, nil
   808  	case "INSERT":
   809  		return s.execInsert(args, true)
   810  	case "NOSERT":
   811  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   812  		// Used for some of the concurrent tests.
   813  		return s.execInsert(args, false)
   814  	}
   815  	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
   816  }
   817  
   818  // When doInsert is true, add the row to the table.
   819  // When doInsert is false do prep-work and error checking, but don't
   820  // actually add the row to the table.
   821  func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
   822  	db := s.c.db
   823  	if len(args) != s.placeholders {
   824  		panic("error in pkg db; should only get here if size is correct")
   825  	}
   826  	db.mu.Lock()
   827  	t, ok := db.table(s.table)
   828  	db.mu.Unlock()
   829  	if !ok {
   830  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   831  	}
   832  
   833  	t.mu.Lock()
   834  	defer t.mu.Unlock()
   835  
   836  	var cols []any
   837  	if doInsert {
   838  		cols = make([]any, len(t.colname))
   839  	}
   840  	argPos := 0
   841  	for n, colname := range s.colName {
   842  		colidx := t.columnIndex(colname)
   843  		if colidx == -1 {
   844  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
   845  		}
   846  		var val any
   847  		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
   848  			if strvalue == "?" {
   849  				val = args[argPos].Value
   850  			} else {
   851  				// Assign value from argument placeholder name.
   852  				for _, a := range args {
   853  					if a.Name == strvalue[1:] {
   854  						val = a.Value
   855  						break
   856  					}
   857  				}
   858  			}
   859  			argPos++
   860  		} else {
   861  			val = s.colValue[n]
   862  		}
   863  		if doInsert {
   864  			cols[colidx] = val
   865  		}
   866  	}
   867  
   868  	if doInsert {
   869  		t.rows = append(t.rows, &row{cols: cols})
   870  	}
   871  	return driver.RowsAffected(1), nil
   872  }
   873  
   874  // hook to simulate broken connections
   875  var hookQueryBadConn func() bool
   876  
   877  func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
   878  	panic("Use QueryContext")
   879  }
   880  
   881  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   882  	if s.panic == "Query" {
   883  		panic(s.panic)
   884  	}
   885  	if s.closed {
   886  		return nil, errClosed
   887  	}
   888  
   889  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   890  		return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
   891  	}
   892  	if s.c.isDirtyAndMark() {
   893  		return nil, errFakeConnSessionDirty
   894  	}
   895  
   896  	err := checkSubsetTypes(s.c.db.allowAny, args)
   897  	if err != nil {
   898  		return nil, err
   899  	}
   900  
   901  	s.touchMem()
   902  	db := s.c.db
   903  	if len(args) != s.placeholders {
   904  		panic("error in pkg db; should only get here if size is correct")
   905  	}
   906  
   907  	setMRows := make([][]*row, 0, 1)
   908  	setColumns := make([][]string, 0, 1)
   909  	setColType := make([][]string, 0, 1)
   910  
   911  	for {
   912  		db.mu.Lock()
   913  		t, ok := db.table(s.table)
   914  		db.mu.Unlock()
   915  		if !ok {
   916  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   917  		}
   918  
   919  		if s.table == "magicquery" {
   920  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   921  				if args[0].Value == "sleep" {
   922  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   923  				}
   924  			}
   925  		}
   926  		if s.table == "tx_status" && s.colName[0] == "tx_status" {
   927  			txStatus := "autocommit"
   928  			if s.c.currTx != nil {
   929  				txStatus = "transaction"
   930  			}
   931  			cursor := &rowsCursor{
   932  				parentMem: s.c,
   933  				posRow:    -1,
   934  				rows: [][]*row{
   935  					{
   936  						{
   937  							cols: []any{
   938  								txStatus,
   939  							},
   940  						},
   941  					},
   942  				},
   943  				cols: [][]string{
   944  					{
   945  						"tx_status",
   946  					},
   947  				},
   948  				colType: [][]string{
   949  					{
   950  						"string",
   951  					},
   952  				},
   953  				errPos: -1,
   954  			}
   955  			return cursor, nil
   956  		}
   957  
   958  		t.mu.Lock()
   959  
   960  		colIdx := make(map[string]int) // select column name -> column index in table
   961  		for _, name := range s.colName {
   962  			idx := t.columnIndex(name)
   963  			if idx == -1 {
   964  				t.mu.Unlock()
   965  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   966  			}
   967  			colIdx[name] = idx
   968  		}
   969  
   970  		mrows := []*row{}
   971  	rows:
   972  		for _, trow := range t.rows {
   973  			// Process the where clause, skipping non-match rows. This is lazy
   974  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   975  			// for test code.
   976  			for _, wcol := range s.whereCol {
   977  				idx := t.columnIndex(wcol.Column)
   978  				if idx == -1 {
   979  					t.mu.Unlock()
   980  					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
   981  				}
   982  				tcol := trow.cols[idx]
   983  				if bs, ok := tcol.([]byte); ok {
   984  					// lazy hack to avoid sprintf %v on a []byte
   985  					tcol = string(bs)
   986  				}
   987  				var argValue any
   988  				if wcol.Placeholder == "?" {
   989  					argValue = args[wcol.Ordinal-1].Value
   990  				} else {
   991  					// Assign arg value from placeholder name.
   992  					for _, a := range args {
   993  						if a.Name == wcol.Placeholder[1:] {
   994  							argValue = a.Value
   995  							break
   996  						}
   997  					}
   998  				}
   999  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
  1000  					continue rows
  1001  				}
  1002  			}
  1003  			mrow := &row{cols: make([]any, len(s.colName))}
  1004  			for seli, name := range s.colName {
  1005  				mrow.cols[seli] = trow.cols[colIdx[name]]
  1006  			}
  1007  			mrows = append(mrows, mrow)
  1008  		}
  1009  
  1010  		var colType []string
  1011  		for _, column := range s.colName {
  1012  			colType = append(colType, t.coltype[t.columnIndex(column)])
  1013  		}
  1014  
  1015  		t.mu.Unlock()
  1016  
  1017  		setMRows = append(setMRows, mrows)
  1018  		setColumns = append(setColumns, s.colName)
  1019  		setColType = append(setColType, colType)
  1020  
  1021  		if s.next == nil {
  1022  			break
  1023  		}
  1024  		s = s.next
  1025  	}
  1026  
  1027  	cursor := &rowsCursor{
  1028  		parentMem: s.c,
  1029  		posRow:    -1,
  1030  		rows:      setMRows,
  1031  		cols:      setColumns,
  1032  		colType:   setColType,
  1033  		errPos:    -1,
  1034  	}
  1035  	return cursor, nil
  1036  }
  1037  
  1038  func (s *fakeStmt) NumInput() int {
  1039  	if s.panic == "NumInput" {
  1040  		panic(s.panic)
  1041  	}
  1042  	return s.placeholders
  1043  }
  1044  
  1045  // hook to simulate broken connections
  1046  var hookCommitBadConn func() bool
  1047  
  1048  func (tx *fakeTx) Commit() error {
  1049  	tx.c.currTx = nil
  1050  	if hookCommitBadConn != nil && hookCommitBadConn() {
  1051  		return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
  1052  	}
  1053  	tx.c.touchMem()
  1054  	return nil
  1055  }
  1056  
  1057  // hook to simulate broken connections
  1058  var hookRollbackBadConn func() bool
  1059  
  1060  func (tx *fakeTx) Rollback() error {
  1061  	tx.c.currTx = nil
  1062  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
  1063  		return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
  1064  	}
  1065  	tx.c.touchMem()
  1066  	return nil
  1067  }
  1068  
// rowsCursor serves in-memory query results through Columns, Next,
// Close and the multi-result-set / column-type methods below. It can hold
// several result sets; cols, colType and rows are each indexed first by
// result set (posSet).
type rowsCursor struct {
	parentMem memToucher // touched (via touchMem) on every cursor operation
	cols      [][]string // column names, one slice per result set
	colType   [][]string // fakedb column type names, one slice per result set
	posSet    int        // index of the current result set
	posRow    int        // index of the current row in rows[posSet]; -1 before the first Next
	rows      [][]*row   // row data, one slice per result set
	closed    bool       // set by Close; Next returns an error once it is true

	// errPos and err are for making Next return early with error: when
	// Next advances posRow to errPos it returns err instead of a row.
	// An errPos of -1 never matches, since Next increments posRow (which
	// starts at -1) before comparing.
	errPos int
	err    error

	// bytesClone holds the copies of []byte values handed out by Next,
	// keyed by the address of the original slice's first byte, so callers
	// never share backing storage with the stored rows.
	// NOTE(review): the original comment said the clones exist "so we're
	// able to corrupt them on close", but Close does not touch this map —
	// confirm whether that corruption step was removed or lives elsewhere.
	bytesClone map[*byte][]byte

	// Every operation writes to line to enable the race detector
	// check for data races.
	// This is separate from the fakeConn.line to allow for drivers that
	// can start multiple queries on the same transaction at the same time.
	line int64

	// closeErr is the error returned by Close.
	closeErr error
}
  1096  
  1097  func (rc *rowsCursor) touchMem() {
  1098  	rc.parentMem.touchMem()
  1099  	rc.line++
  1100  }
  1101  
  1102  func (rc *rowsCursor) Close() error {
  1103  	rc.touchMem()
  1104  	rc.parentMem.touchMem()
  1105  	rc.closed = true
  1106  	return rc.closeErr
  1107  }
  1108  
  1109  func (rc *rowsCursor) Columns() []string {
  1110  	return rc.cols[rc.posSet]
  1111  }
  1112  
  1113  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
  1114  	return colTypeToReflectType(rc.colType[rc.posSet][index])
  1115  }
  1116  
  1117  var rowsCursorNextHook func(dest []driver.Value) error
  1118  
  1119  func (rc *rowsCursor) Next(dest []driver.Value) error {
  1120  	if rowsCursorNextHook != nil {
  1121  		return rowsCursorNextHook(dest)
  1122  	}
  1123  
  1124  	if rc.closed {
  1125  		return errors.New("fakedb: cursor is closed")
  1126  	}
  1127  	rc.touchMem()
  1128  	rc.posRow++
  1129  	if rc.posRow == rc.errPos {
  1130  		return rc.err
  1131  	}
  1132  	if rc.posRow >= len(rc.rows[rc.posSet]) {
  1133  		return io.EOF // per interface spec
  1134  	}
  1135  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
  1136  		// TODO(bradfitz): convert to subset types? naah, I
  1137  		// think the subset types should only be input to
  1138  		// driver, but the sql package should be able to handle
  1139  		// a wider range of types coming out of drivers. all
  1140  		// for ease of drivers, and to prevent drivers from
  1141  		// messing up conversions or doing them differently.
  1142  		dest[i] = v
  1143  
  1144  		if bs, ok := v.([]byte); ok {
  1145  			if rc.bytesClone == nil {
  1146  				rc.bytesClone = make(map[*byte][]byte)
  1147  			}
  1148  			clone, ok := rc.bytesClone[&bs[0]]
  1149  			if !ok {
  1150  				clone = make([]byte, len(bs))
  1151  				copy(clone, bs)
  1152  				rc.bytesClone[&bs[0]] = clone
  1153  			}
  1154  			dest[i] = clone
  1155  		}
  1156  	}
  1157  	return nil
  1158  }
  1159  
  1160  func (rc *rowsCursor) HasNextResultSet() bool {
  1161  	rc.touchMem()
  1162  	return rc.posSet < len(rc.rows)-1
  1163  }
  1164  
  1165  func (rc *rowsCursor) NextResultSet() error {
  1166  	rc.touchMem()
  1167  	if rc.HasNextResultSet() {
  1168  		rc.posSet++
  1169  		rc.posRow = -1
  1170  		return nil
  1171  	}
  1172  	return io.EOF // Per interface spec.
  1173  }
  1174  
  1175  // fakeDriverString is like driver.String, but indirects pointers like
  1176  // DefaultValueConverter.
  1177  //
  1178  // This could be surprising behavior to retroactively apply to
  1179  // driver.String now that Go1 is out, but this is convenient for
  1180  // our TestPointerParamsAndScans.
  1181  type fakeDriverString struct{}
  1182  
  1183  func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
  1184  	switch c := v.(type) {
  1185  	case string, []byte:
  1186  		return v, nil
  1187  	case *string:
  1188  		if c == nil {
  1189  			return nil, nil
  1190  		}
  1191  		return *c, nil
  1192  	}
  1193  	return fmt.Sprintf("%v", v), nil
  1194  }
  1195  
  1196  type anyTypeConverter struct{}
  1197  
  1198  func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
  1199  	return v, nil
  1200  }
  1201  
  1202  func converterForType(typ string) driver.ValueConverter {
  1203  	switch typ {
  1204  	case "bool":
  1205  		return driver.Bool
  1206  	case "nullbool":
  1207  		return driver.Null{Converter: driver.Bool}
  1208  	case "byte", "int16":
  1209  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1210  	case "int32":
  1211  		return driver.Int32
  1212  	case "nullbyte", "nullint32", "nullint16":
  1213  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1214  	case "string":
  1215  		return driver.NotNull{Converter: fakeDriverString{}}
  1216  	case "nullstring":
  1217  		return driver.Null{Converter: fakeDriverString{}}
  1218  	case "int64":
  1219  		// TODO(coopernurse): add type-specific converter
  1220  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1221  	case "nullint64":
  1222  		// TODO(coopernurse): add type-specific converter
  1223  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1224  	case "float64":
  1225  		// TODO(coopernurse): add type-specific converter
  1226  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1227  	case "nullfloat64":
  1228  		// TODO(coopernurse): add type-specific converter
  1229  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1230  	case "datetime":
  1231  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1232  	case "nulldatetime":
  1233  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1234  	case "any":
  1235  		return anyTypeConverter{}
  1236  	}
  1237  	panic("invalid fakedb column type of " + typ)
  1238  }
  1239  
  1240  func colTypeToReflectType(typ string) reflect.Type {
  1241  	switch typ {
  1242  	case "bool":
  1243  		return reflect.TypeOf(false)
  1244  	case "nullbool":
  1245  		return reflect.TypeOf(NullBool{})
  1246  	case "int16":
  1247  		return reflect.TypeOf(int16(0))
  1248  	case "nullint16":
  1249  		return reflect.TypeOf(NullInt16{})
  1250  	case "int32":
  1251  		return reflect.TypeOf(int32(0))
  1252  	case "nullint32":
  1253  		return reflect.TypeOf(NullInt32{})
  1254  	case "string":
  1255  		return reflect.TypeOf("")
  1256  	case "nullstring":
  1257  		return reflect.TypeOf(NullString{})
  1258  	case "int64":
  1259  		return reflect.TypeOf(int64(0))
  1260  	case "nullint64":
  1261  		return reflect.TypeOf(NullInt64{})
  1262  	case "float64":
  1263  		return reflect.TypeOf(float64(0))
  1264  	case "nullfloat64":
  1265  		return reflect.TypeOf(NullFloat64{})
  1266  	case "datetime":
  1267  		return reflect.TypeOf(time.Time{})
  1268  	case "any":
  1269  		return reflect.TypeOf(new(any)).Elem()
  1270  	}
  1271  	panic("invalid fakedb column type of " + typ)
  1272  }
  1273  

View as plain text