Source file
src/database/sql/fakedb_test.go
1
2
3
4
5 package sql
6
7 import (
8 "context"
9 "database/sql/driver"
10 "errors"
11 "fmt"
12 "io"
13 "reflect"
14 "sort"
15 "strconv"
16 "strings"
17 "sync"
18 "testing"
19 "time"
20 )
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// fakeDriver is a fake database/sql driver used by the tests. Connections
// opened through it share in-memory fakeDBs that are looked up by name.
type fakeDriver struct {
	// mu is held while openCount, closeCount and dbs are updated.
	mu         sync.Mutex
	openCount  int // number of Open calls that got past the open hook
	closeCount int // number of fakeConn.Close calls that succeeded
	// When waitCh is non-nil, the next Open sends on waitingCh, blocks
	// until waitCh yields a value, and then clears both channels (without
	// holding mu).
	waitCh    chan struct{}
	waitingCh chan struct{}
	dbs       map[string]*fakeDB // fakeDBs by name, created lazily by getDB
}
54
55 type fakeConnector struct {
56 name string
57
58 waiter func(context.Context)
59 closed bool
60 }
61
62 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
63 conn, err := fdriver.Open(c.name)
64 conn.(*fakeConn).waiter = c.waiter
65 return conn, err
66 }
67
68 func (c *fakeConnector) Driver() driver.Driver {
69 return fdriver
70 }
71
72 func (c *fakeConnector) Close() error {
73 if c.closed {
74 return errors.New("fakedb: connector is closed")
75 }
76 c.closed = true
77 return nil
78 }
79
80 type fakeDriverCtx struct {
81 fakeDriver
82 }
83
84 var _ driver.DriverContext = &fakeDriverCtx{}
85
86 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
87 return &fakeConnector{name: name}, nil
88 }
89
// fakeDB is an in-memory database. fakeDriver.getDB hands the same fakeDB
// to every connection opened with the same name.
type fakeDB struct {
	name string

	// mu is held by callers while they read or replace tables.
	mu     sync.Mutex
	tables map[string]*table
	// badConn is toggled (without mu) by fakeConn.isBad for connections
	// opened with the "badConn" DSN option.
	badConn bool
	// allowAny is passed to checkSubsetTypes; when true, argument values of
	// any type are accepted.
	allowAny bool
}
98
// fakeError is an error with a descriptive message and an optional wrapped
// cause, so tests can return e.g. driver.ErrBadConn in a form that
// errors.Is still recognizes.
type fakeError struct {
	Message string
	Wrapped error
}

// Error returns only the message; the wrapped cause is not included.
func (e fakeError) Error() string { return e.Message }

// Unwrap returns the wrapped cause (possibly nil).
func (e fakeError) Unwrap() error { return e.Wrapped }
111
// table is the in-memory storage of one fakedb table: parallel slices of
// column names and types, plus the stored rows. mu is held while the
// table's rows are read or modified.
type table struct {
	mu      sync.Mutex
	colname []string
	coltype []string
	rows    []*row
}

// columnIndex returns the position of the first column called name, or -1
// if the table has no such column.
func (t *table) columnIndex(name string) int {
	for i := 0; i < len(t.colname); i++ {
		if t.colname[i] == name {
			return i
		}
	}
	return -1
}

// row holds one stored or selected row, one value per column.
type row struct {
	cols []any
}
131
// memToucher is implemented by fakeConn and rowsCursor, and by fakeStmt via
// embedding. Implementations mutate a plain counter without any locking.
// NOTE(review): presumably this exists so the race detector flags
// unsynchronized concurrent use of these objects; confirm intent.
type memToucher interface {
	// touchMem records an access by mutating unsynchronized state.
	touchMem()
}
136
// fakeConn is a connection to a fakeDB. Only the statement counters
// updated through incrStat are protected by mu; all other fields are read
// and written without locking.
type fakeConn struct {
	db *fakeDB // nil once Close has succeeded

	// currTx is the open transaction set by Begin; Commit, Rollback and
	// ResetSession clear it.
	currTx *fakeTx

	// line is bumped by touchMem.
	line int64

	// mu is held by incrStat while it increments stmtsMade or stmtsClosed.
	mu          sync.Mutex
	stmtsMade   int // statements counted by PrepareContext
	stmtsClosed int // statements counted by fakeStmt.Close
	numPrepare  int // PrepareContext calls (incremented without mu)

	// bad is set when the DSN's second ";"-separated segment is "badConn";
	// isBad then alternates between true and false on each call.
	bad bool
	// stickyBad makes isBad always report true and makes PrepareContext,
	// ExecContext and QueryContext fail with a wrapped driver.ErrBadConn.
	stickyBad bool

	// skipDirtySession disables the check in isDirtyAndMark;
	// prepareInsert sets it for "table"-typed literal values.
	skipDirtySession bool

	// dirtySession is set by isDirtyAndMark once a statement runs and is
	// cleared by ResetSession.
	dirtySession bool

	// waiter, if non-nil, is called by PrepareContext for each statement,
	// before ctx is checked for cancellation.
	waiter func(context.Context)
}
166
167 func (c *fakeConn) touchMem() {
168 c.line++
169 }
170
171 func (c *fakeConn) incrStat(v *int) {
172 c.mu.Lock()
173 *v++
174 c.mu.Unlock()
175 }
176
// fakeTx is the driver.Tx returned by fakeConn.Begin. Commit and Rollback
// detach it from its connection by clearing c.currTx.
type fakeTx struct {
	c *fakeConn
}
180
// boundCol is a WHERE-clause column of a prepared SELECT, together with the
// placeholder it is compared against.
type boundCol struct {
	Column      string // column name in the table
	Placeholder string // "?" for a positional argument, "?name" for a named one
	Ordinal     int    // 1-based position among the statement's placeholders
}
186
// fakeStmt is a prepared statement. A query made of several ";"-separated
// statements is represented as a linked list through next.
type fakeStmt struct {
	memToucher           // set to the owning connection by PrepareContext
	c          *fakeConn // owning connection
	q          string    // source text of this single statement

	cmd   string        // WIPE, SELECT, CREATE, INSERT or NOSERT
	table string        // table the statement operates on
	panic string        // method name to panic in, from a "PANIC|name|" prefix
	wait  time.Duration // delay from a "WAIT|duration|" prefix

	next *fakeStmt // next statement of a multi-statement query, or nil

	closed bool

	colName      []string // SELECT: selected columns; CREATE/INSERT: column names
	colType      []string // CREATE: column types
	colValue     []any    // INSERT: pre-bound value or placeholder string per column
	placeholders int      // number of placeholders the statement expects

	whereCol []boundCol // SELECT: WHERE-clause columns

	placeholderConverter []driver.ValueConverter // INSERT: converter per placeholder
}
210
// fdriver is the shared fakeDriver instance; fakeConnector and
// fakeConn.Close reach the driver through this variable.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
216
// contains reports whether y occurs anywhere in list.
func contains(list []string, y string) bool {
	for i := range list {
		if list[i] == y {
			return true
		}
	}
	return false
}
225
// Dummy embeds a driver.Driver that is left nil; TestDrivers only registers
// it under the name "invalid". Calling Open on a zero Dummy would panic,
// because the embedded interface is nil.
type Dummy struct {
	driver.Driver
}
229
230 func TestDrivers(t *testing.T) {
231 unregisterAllDrivers()
232 Register("test", fdriver)
233 Register("invalid", Dummy{})
234 all := Drivers()
235 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
236 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
237 }
238 }
239
240
// hookOpenErr optionally injects failures into fakeDriver.Open. When fn is
// non-nil, Open calls it first and returns its error, if any.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the Open failure hook; nil removes it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
251
252
253
254
255
256
257
258 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
259 hookOpenErr.Lock()
260 fn := hookOpenErr.fn
261 hookOpenErr.Unlock()
262 if fn != nil {
263 if err := fn(); err != nil {
264 return nil, err
265 }
266 }
267 parts := strings.Split(dsn, ";")
268 if len(parts) < 1 {
269 return nil, errors.New("fakedb: no database name")
270 }
271 name := parts[0]
272
273 db := d.getDB(name)
274
275 d.mu.Lock()
276 d.openCount++
277 d.mu.Unlock()
278 conn := &fakeConn{db: db}
279
280 if len(parts) >= 2 && parts[1] == "badConn" {
281 conn.bad = true
282 }
283 if d.waitCh != nil {
284 d.waitingCh <- struct{}{}
285 <-d.waitCh
286 d.waitCh = nil
287 d.waitingCh = nil
288 }
289 return conn, nil
290 }
291
292 func (d *fakeDriver) getDB(name string) *fakeDB {
293 d.mu.Lock()
294 defer d.mu.Unlock()
295 if d.dbs == nil {
296 d.dbs = make(map[string]*fakeDB)
297 }
298 db, ok := d.dbs[name]
299 if !ok {
300 db = &fakeDB{name: name}
301 d.dbs[name] = db
302 }
303 return db
304 }
305
306 func (db *fakeDB) wipe() {
307 db.mu.Lock()
308 defer db.mu.Unlock()
309 db.tables = nil
310 }
311
312 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
313 db.mu.Lock()
314 defer db.mu.Unlock()
315 if db.tables == nil {
316 db.tables = make(map[string]*table)
317 }
318 if _, exist := db.tables[name]; exist {
319 return fmt.Errorf("fakedb: table %q already exists", name)
320 }
321 if len(columnNames) != len(columnTypes) {
322 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
323 name, len(columnNames), len(columnTypes))
324 }
325 db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
326 return nil
327 }
328
329
330 func (db *fakeDB) table(table string) (*table, bool) {
331 if db.tables == nil {
332 return nil, false
333 }
334 t, ok := db.tables[table]
335 return t, ok
336 }
337
338 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
339 db.mu.Lock()
340 defer db.mu.Unlock()
341 t, ok := db.table(table)
342 if !ok {
343 return
344 }
345 for n, cname := range t.colname {
346 if cname == column {
347 return t.coltype[n], true
348 }
349 }
350 return "", false
351 }
352
353 func (c *fakeConn) isBad() bool {
354 if c.stickyBad {
355 return true
356 } else if c.bad {
357 if c.db == nil {
358 return false
359 }
360
361 c.db.badConn = !c.db.badConn
362 return c.db.badConn
363 } else {
364 return false
365 }
366 }
367
368 func (c *fakeConn) isDirtyAndMark() bool {
369 if c.skipDirtySession {
370 return false
371 }
372 if c.currTx != nil {
373 c.dirtySession = true
374 return false
375 }
376 if c.dirtySession {
377 return true
378 }
379 c.dirtySession = true
380 return false
381 }
382
383 func (c *fakeConn) Begin() (driver.Tx, error) {
384 if c.isBad() {
385 return nil, fakeError{Wrapped: driver.ErrBadConn}
386 }
387 if c.currTx != nil {
388 return nil, errors.New("fakedb: already in a transaction")
389 }
390 c.touchMem()
391 c.currTx = &fakeTx{c: c}
392 return c.currTx, nil
393 }
394
395 var hookPostCloseConn struct {
396 sync.Mutex
397 fn func(*fakeConn, error)
398 }
399
400 func setHookpostCloseConn(fn func(*fakeConn, error)) {
401 hookPostCloseConn.Lock()
402 defer hookPostCloseConn.Unlock()
403 hookPostCloseConn.fn = fn
404 }
405
// testStrictClose, when non-nil, gets an Errorf for every fakeConn.Close
// that fails.
var testStrictClose *testing.T

// setStrictFakeConnClose makes failing fakeConn.Close calls report an error
// on t; passing nil disables the check.
func setStrictFakeConnClose(t *testing.T) { testStrictClose = t }
413
414 func (c *fakeConn) ResetSession(ctx context.Context) error {
415 c.dirtySession = false
416 c.currTx = nil
417 if c.isBad() {
418 return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
419 }
420 return nil
421 }
422
423 var _ driver.Validator = (*fakeConn)(nil)
424
425 func (c *fakeConn) IsValid() bool {
426 return !c.isBad()
427 }
428
429 func (c *fakeConn) Close() (err error) {
430 drv := fdriver.(*fakeDriver)
431 defer func() {
432 if err != nil && testStrictClose != nil {
433 testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
434 }
435 hookPostCloseConn.Lock()
436 fn := hookPostCloseConn.fn
437 hookPostCloseConn.Unlock()
438 if fn != nil {
439 fn(c, err)
440 }
441 if err == nil {
442 drv.mu.Lock()
443 drv.closeCount++
444 drv.mu.Unlock()
445 }
446 }()
447 c.touchMem()
448 if c.currTx != nil {
449 return errors.New("fakedb: can't close fakeConn; in a Transaction")
450 }
451 if c.db == nil {
452 return errors.New("fakedb: can't close fakeConn; already closed")
453 }
454 if c.stmtsMade > c.stmtsClosed {
455 return errors.New("fakedb: can't close; dangling statement(s)")
456 }
457 c.db = nil
458 return nil
459 }
460
// checkSubsetTypes verifies that every argument value has one of the types
// int64, float64, bool, nil, []byte, string or time.Time. Unless allowAny
// is set, the first argument of any other type is rejected with an error
// naming its ordinal, value and type.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for _, arg := range args {
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
473
474 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
475
476 panic("ExecContext was not called.")
477 }
478
479 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
480
481
482
483
484 err := checkSubsetTypes(c.db.allowAny, args)
485 if err != nil {
486 return nil, err
487 }
488 return nil, driver.ErrSkip
489 }
490
491 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
492
493 panic("QueryContext was not called.")
494 }
495
496 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
497
498
499
500
501 err := checkSubsetTypes(c.db.allowAny, args)
502 if err != nil {
503 return nil, err
504 }
505 return nil, driver.ErrSkip
506 }
507
// errf formats msg with args and returns the text as an error prefixed
// with "fakedb: ".
func errf(msg string, args ...any) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
511
512
513
514
515 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
516 if len(parts) != 3 {
517 stmt.Close()
518 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
519 }
520 stmt.table = parts[0]
521
522 stmt.colName = strings.Split(parts[1], ",")
523 for n, colspec := range strings.Split(parts[2], ",") {
524 if colspec == "" {
525 continue
526 }
527 nameVal := strings.Split(colspec, "=")
528 if len(nameVal) != 2 {
529 stmt.Close()
530 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
531 }
532 column, value := nameVal[0], nameVal[1]
533 _, ok := c.db.columnType(stmt.table, column)
534 if !ok {
535 stmt.Close()
536 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
537 }
538 if !strings.HasPrefix(value, "?") {
539 stmt.Close()
540 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
541 stmt.table, column)
542 }
543 stmt.placeholders++
544 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
545 }
546 return stmt, nil
547 }
548
549
550 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
551 if len(parts) != 2 {
552 stmt.Close()
553 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
554 }
555 stmt.table = parts[0]
556 for n, colspec := range strings.Split(parts[1], ",") {
557 nameType := strings.Split(colspec, "=")
558 if len(nameType) != 2 {
559 stmt.Close()
560 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
561 }
562 stmt.colName = append(stmt.colName, nameType[0])
563 stmt.colType = append(stmt.colType, nameType[1])
564 }
565 return stmt, nil
566 }
567
568
569 func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
570 if len(parts) != 2 {
571 stmt.Close()
572 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
573 }
574 stmt.table = parts[0]
575 for n, colspec := range strings.Split(parts[1], ",") {
576 nameVal := strings.Split(colspec, "=")
577 if len(nameVal) != 2 {
578 stmt.Close()
579 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
580 }
581 column, value := nameVal[0], nameVal[1]
582 ctype, ok := c.db.columnType(stmt.table, column)
583 if !ok {
584 stmt.Close()
585 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
586 }
587 stmt.colName = append(stmt.colName, column)
588
589 if !strings.HasPrefix(value, "?") {
590 var subsetVal any
591
592 switch ctype {
593 case "string":
594 subsetVal = []byte(value)
595 case "blob":
596 subsetVal = []byte(value)
597 case "int32":
598 i, err := strconv.Atoi(value)
599 if err != nil {
600 stmt.Close()
601 return nil, errf("invalid conversion to int32 from %q", value)
602 }
603 subsetVal = int64(i)
604 case "table":
605 c.skipDirtySession = true
606 vparts := strings.Split(value, "!")
607
608 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
609 if err != nil {
610 return nil, err
611 }
612 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
613 substmt.Close()
614 if err != nil {
615 return nil, err
616 }
617 subsetVal = cursor
618 default:
619 stmt.Close()
620 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
621 }
622 stmt.colValue = append(stmt.colValue, subsetVal)
623 } else {
624 stmt.placeholders++
625 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
626 stmt.colValue = append(stmt.colValue, value)
627 }
628 }
629 return stmt, nil
630 }
631
632
633 var hookPrepareBadConn func() bool
634
635 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
636 panic("use PrepareContext")
637 }
638
639 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
640 c.numPrepare++
641 if c.db == nil {
642 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
643 }
644
645 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
646 return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
647 }
648
649 c.touchMem()
650 var firstStmt, prev *fakeStmt
651 for _, query := range strings.Split(query, ";") {
652 parts := strings.Split(query, "|")
653 if len(parts) < 1 {
654 return nil, errf("empty query")
655 }
656 stmt := &fakeStmt{q: query, c: c, memToucher: c}
657 if firstStmt == nil {
658 firstStmt = stmt
659 }
660 if len(parts) >= 3 {
661 switch parts[0] {
662 case "PANIC":
663 stmt.panic = parts[1]
664 parts = parts[2:]
665 case "WAIT":
666 wait, err := time.ParseDuration(parts[1])
667 if err != nil {
668 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
669 }
670 parts = parts[2:]
671 stmt.wait = wait
672 }
673 }
674 cmd := parts[0]
675 stmt.cmd = cmd
676 parts = parts[1:]
677
678 if c.waiter != nil {
679 c.waiter(ctx)
680 if err := ctx.Err(); err != nil {
681 return nil, err
682 }
683 }
684
685 if stmt.wait > 0 {
686 wait := time.NewTimer(stmt.wait)
687 select {
688 case <-wait.C:
689 case <-ctx.Done():
690 wait.Stop()
691 return nil, ctx.Err()
692 }
693 }
694
695 c.incrStat(&c.stmtsMade)
696 var err error
697 switch cmd {
698 case "WIPE":
699
700 case "SELECT":
701 stmt, err = c.prepareSelect(stmt, parts)
702 case "CREATE":
703 stmt, err = c.prepareCreate(stmt, parts)
704 case "INSERT":
705 stmt, err = c.prepareInsert(ctx, stmt, parts)
706 case "NOSERT":
707
708
709 stmt, err = c.prepareInsert(ctx, stmt, parts)
710 default:
711 stmt.Close()
712 return nil, errf("unsupported command type %q", cmd)
713 }
714 if err != nil {
715 return nil, err
716 }
717 if prev != nil {
718 prev.next = stmt
719 }
720 prev = stmt
721 }
722 return firstStmt, nil
723 }
724
725 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
726 if s.panic == "ColumnConverter" {
727 panic(s.panic)
728 }
729 if len(s.placeholderConverter) == 0 {
730 return driver.DefaultParameterConverter
731 }
732 return s.placeholderConverter[idx]
733 }
734
735 func (s *fakeStmt) Close() error {
736 if s.panic == "Close" {
737 panic(s.panic)
738 }
739 if s.c == nil {
740 panic("nil conn in fakeStmt.Close")
741 }
742 if s.c.db == nil {
743 panic("in fakeStmt.Close, conn's db is nil (already closed)")
744 }
745 s.touchMem()
746 if !s.closed {
747 s.c.incrStat(&s.c.stmtsClosed)
748 s.closed = true
749 }
750 if s.next != nil {
751 s.next.Close()
752 }
753 return nil
754 }
755
756 var errClosed = errors.New("fakedb: statement has been closed")
757
758
759 var hookExecBadConn func() bool
760
761 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
762 panic("Using ExecContext")
763 }
764
765 var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
766
767 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
768 if s.panic == "Exec" {
769 panic(s.panic)
770 }
771 if s.closed {
772 return nil, errClosed
773 }
774
775 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
776 return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
777 }
778 if s.c.isDirtyAndMark() {
779 return nil, errFakeConnSessionDirty
780 }
781
782 err := checkSubsetTypes(s.c.db.allowAny, args)
783 if err != nil {
784 return nil, err
785 }
786 s.touchMem()
787
788 if s.wait > 0 {
789 time.Sleep(s.wait)
790 }
791
792 select {
793 default:
794 case <-ctx.Done():
795 return nil, ctx.Err()
796 }
797
798 db := s.c.db
799 switch s.cmd {
800 case "WIPE":
801 db.wipe()
802 return driver.ResultNoRows, nil
803 case "CREATE":
804 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
805 return nil, err
806 }
807 return driver.ResultNoRows, nil
808 case "INSERT":
809 return s.execInsert(args, true)
810 case "NOSERT":
811
812
813 return s.execInsert(args, false)
814 }
815 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
816 }
817
818
819
820
821 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
822 db := s.c.db
823 if len(args) != s.placeholders {
824 panic("error in pkg db; should only get here if size is correct")
825 }
826 db.mu.Lock()
827 t, ok := db.table(s.table)
828 db.mu.Unlock()
829 if !ok {
830 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
831 }
832
833 t.mu.Lock()
834 defer t.mu.Unlock()
835
836 var cols []any
837 if doInsert {
838 cols = make([]any, len(t.colname))
839 }
840 argPos := 0
841 for n, colname := range s.colName {
842 colidx := t.columnIndex(colname)
843 if colidx == -1 {
844 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
845 }
846 var val any
847 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
848 if strvalue == "?" {
849 val = args[argPos].Value
850 } else {
851
852 for _, a := range args {
853 if a.Name == strvalue[1:] {
854 val = a.Value
855 break
856 }
857 }
858 }
859 argPos++
860 } else {
861 val = s.colValue[n]
862 }
863 if doInsert {
864 cols[colidx] = val
865 }
866 }
867
868 if doInsert {
869 t.rows = append(t.rows, &row{cols: cols})
870 }
871 return driver.RowsAffected(1), nil
872 }
873
874
875 var hookQueryBadConn func() bool
876
877 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
878 panic("Use QueryContext")
879 }
880
881 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
882 if s.panic == "Query" {
883 panic(s.panic)
884 }
885 if s.closed {
886 return nil, errClosed
887 }
888
889 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
890 return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
891 }
892 if s.c.isDirtyAndMark() {
893 return nil, errFakeConnSessionDirty
894 }
895
896 err := checkSubsetTypes(s.c.db.allowAny, args)
897 if err != nil {
898 return nil, err
899 }
900
901 s.touchMem()
902 db := s.c.db
903 if len(args) != s.placeholders {
904 panic("error in pkg db; should only get here if size is correct")
905 }
906
907 setMRows := make([][]*row, 0, 1)
908 setColumns := make([][]string, 0, 1)
909 setColType := make([][]string, 0, 1)
910
911 for {
912 db.mu.Lock()
913 t, ok := db.table(s.table)
914 db.mu.Unlock()
915 if !ok {
916 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
917 }
918
919 if s.table == "magicquery" {
920 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
921 if args[0].Value == "sleep" {
922 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
923 }
924 }
925 }
926 if s.table == "tx_status" && s.colName[0] == "tx_status" {
927 txStatus := "autocommit"
928 if s.c.currTx != nil {
929 txStatus = "transaction"
930 }
931 cursor := &rowsCursor{
932 parentMem: s.c,
933 posRow: -1,
934 rows: [][]*row{
935 {
936 {
937 cols: []any{
938 txStatus,
939 },
940 },
941 },
942 },
943 cols: [][]string{
944 {
945 "tx_status",
946 },
947 },
948 colType: [][]string{
949 {
950 "string",
951 },
952 },
953 errPos: -1,
954 }
955 return cursor, nil
956 }
957
958 t.mu.Lock()
959
960 colIdx := make(map[string]int)
961 for _, name := range s.colName {
962 idx := t.columnIndex(name)
963 if idx == -1 {
964 t.mu.Unlock()
965 return nil, fmt.Errorf("fakedb: unknown column name %q", name)
966 }
967 colIdx[name] = idx
968 }
969
970 mrows := []*row{}
971 rows:
972 for _, trow := range t.rows {
973
974
975
976 for _, wcol := range s.whereCol {
977 idx := t.columnIndex(wcol.Column)
978 if idx == -1 {
979 t.mu.Unlock()
980 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
981 }
982 tcol := trow.cols[idx]
983 if bs, ok := tcol.([]byte); ok {
984
985 tcol = string(bs)
986 }
987 var argValue any
988 if wcol.Placeholder == "?" {
989 argValue = args[wcol.Ordinal-1].Value
990 } else {
991
992 for _, a := range args {
993 if a.Name == wcol.Placeholder[1:] {
994 argValue = a.Value
995 break
996 }
997 }
998 }
999 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
1000 continue rows
1001 }
1002 }
1003 mrow := &row{cols: make([]any, len(s.colName))}
1004 for seli, name := range s.colName {
1005 mrow.cols[seli] = trow.cols[colIdx[name]]
1006 }
1007 mrows = append(mrows, mrow)
1008 }
1009
1010 var colType []string
1011 for _, column := range s.colName {
1012 colType = append(colType, t.coltype[t.columnIndex(column)])
1013 }
1014
1015 t.mu.Unlock()
1016
1017 setMRows = append(setMRows, mrows)
1018 setColumns = append(setColumns, s.colName)
1019 setColType = append(setColType, colType)
1020
1021 if s.next == nil {
1022 break
1023 }
1024 s = s.next
1025 }
1026
1027 cursor := &rowsCursor{
1028 parentMem: s.c,
1029 posRow: -1,
1030 rows: setMRows,
1031 cols: setColumns,
1032 colType: setColType,
1033 errPos: -1,
1034 }
1035 return cursor, nil
1036 }
1037
1038 func (s *fakeStmt) NumInput() int {
1039 if s.panic == "NumInput" {
1040 panic(s.panic)
1041 }
1042 return s.placeholders
1043 }
1044
1045
1046 var hookCommitBadConn func() bool
1047
1048 func (tx *fakeTx) Commit() error {
1049 tx.c.currTx = nil
1050 if hookCommitBadConn != nil && hookCommitBadConn() {
1051 return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1052 }
1053 tx.c.touchMem()
1054 return nil
1055 }
1056
1057
1058 var hookRollbackBadConn func() bool
1059
1060 func (tx *fakeTx) Rollback() error {
1061 tx.c.currTx = nil
1062 if hookRollbackBadConn != nil && hookRollbackBadConn() {
1063 return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1064 }
1065 tx.c.touchMem()
1066 return nil
1067 }
1068
// rowsCursor is the driver.Rows returned by fakeStmt.QueryContext. It holds
// one or more result sets, each with its own column names, column types and
// rows.
type rowsCursor struct {
	parentMem memToucher // touched together with the cursor by touchMem
	cols      [][]string // column names per result set
	colType   [][]string // fakedb column types per result set
	posSet    int        // index of the current result set
	posRow    int        // index of the current row; -1 before the first Next
	rows      [][]*row   // rows per result set
	closed    bool

	// errPos is the row index at which Next returns err instead of a row.
	// QueryContext sets it to -1, which Next never reaches because posRow
	// is incremented before the comparison.
	errPos int
	err    error

	// bytesClone maps the address of the first byte of a stored []byte to
	// the copy Next hands out, so each stored slice is copied at most once
	// per cursor.
	bytesClone map[*byte][]byte

	// line is bumped by touchMem.
	line int64

	// closeErr is returned by Close.
	closeErr error
}
1096
1097 func (rc *rowsCursor) touchMem() {
1098 rc.parentMem.touchMem()
1099 rc.line++
1100 }
1101
1102 func (rc *rowsCursor) Close() error {
1103 rc.touchMem()
1104 rc.parentMem.touchMem()
1105 rc.closed = true
1106 return rc.closeErr
1107 }
1108
1109 func (rc *rowsCursor) Columns() []string {
1110 return rc.cols[rc.posSet]
1111 }
1112
1113 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
1114 return colTypeToReflectType(rc.colType[rc.posSet][index])
1115 }
1116
1117 var rowsCursorNextHook func(dest []driver.Value) error
1118
1119 func (rc *rowsCursor) Next(dest []driver.Value) error {
1120 if rowsCursorNextHook != nil {
1121 return rowsCursorNextHook(dest)
1122 }
1123
1124 if rc.closed {
1125 return errors.New("fakedb: cursor is closed")
1126 }
1127 rc.touchMem()
1128 rc.posRow++
1129 if rc.posRow == rc.errPos {
1130 return rc.err
1131 }
1132 if rc.posRow >= len(rc.rows[rc.posSet]) {
1133 return io.EOF
1134 }
1135 for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1136
1137
1138
1139
1140
1141
1142 dest[i] = v
1143
1144 if bs, ok := v.([]byte); ok {
1145 if rc.bytesClone == nil {
1146 rc.bytesClone = make(map[*byte][]byte)
1147 }
1148 clone, ok := rc.bytesClone[&bs[0]]
1149 if !ok {
1150 clone = make([]byte, len(bs))
1151 copy(clone, bs)
1152 rc.bytesClone[&bs[0]] = clone
1153 }
1154 dest[i] = clone
1155 }
1156 }
1157 return nil
1158 }
1159
1160 func (rc *rowsCursor) HasNextResultSet() bool {
1161 rc.touchMem()
1162 return rc.posSet < len(rc.rows)-1
1163 }
1164
1165 func (rc *rowsCursor) NextResultSet() error {
1166 rc.touchMem()
1167 if rc.HasNextResultSet() {
1168 rc.posSet++
1169 rc.posRow = -1
1170 return nil
1171 }
1172 return io.EOF
1173 }
1174
1175
1176
1177
1178
1179
1180
// fakeDriverString converts values to strings. string and []byte pass
// through unchanged, *string is dereferenced (a nil pointer yields nil),
// and anything else is formatted with %v.
type fakeDriverString struct{}

// ConvertValue implements driver.ValueConverter; it never fails.
func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
	if p, isPtr := v.(*string); isPtr {
		if p == nil {
			return nil, nil
		}
		return *p, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}

// anyTypeConverter passes every value through unchanged.
type anyTypeConverter struct{}

// ConvertValue implements driver.ValueConverter; it never fails.
func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
	return driver.Value(v), nil
}
1201
1202 func converterForType(typ string) driver.ValueConverter {
1203 switch typ {
1204 case "bool":
1205 return driver.Bool
1206 case "nullbool":
1207 return driver.Null{Converter: driver.Bool}
1208 case "byte", "int16":
1209 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1210 case "int32":
1211 return driver.Int32
1212 case "nullbyte", "nullint32", "nullint16":
1213 return driver.Null{Converter: driver.DefaultParameterConverter}
1214 case "string":
1215 return driver.NotNull{Converter: fakeDriverString{}}
1216 case "nullstring":
1217 return driver.Null{Converter: fakeDriverString{}}
1218 case "int64":
1219
1220 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1221 case "nullint64":
1222
1223 return driver.Null{Converter: driver.DefaultParameterConverter}
1224 case "float64":
1225
1226 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1227 case "nullfloat64":
1228
1229 return driver.Null{Converter: driver.DefaultParameterConverter}
1230 case "datetime":
1231 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1232 case "nulldatetime":
1233 return driver.Null{Converter: driver.DefaultParameterConverter}
1234 case "any":
1235 return anyTypeConverter{}
1236 }
1237 panic("invalid fakedb column type of " + typ)
1238 }
1239
1240 func colTypeToReflectType(typ string) reflect.Type {
1241 switch typ {
1242 case "bool":
1243 return reflect.TypeOf(false)
1244 case "nullbool":
1245 return reflect.TypeOf(NullBool{})
1246 case "int16":
1247 return reflect.TypeOf(int16(0))
1248 case "nullint16":
1249 return reflect.TypeOf(NullInt16{})
1250 case "int32":
1251 return reflect.TypeOf(int32(0))
1252 case "nullint32":
1253 return reflect.TypeOf(NullInt32{})
1254 case "string":
1255 return reflect.TypeOf("")
1256 case "nullstring":
1257 return reflect.TypeOf(NullString{})
1258 case "int64":
1259 return reflect.TypeOf(int64(0))
1260 case "nullint64":
1261 return reflect.TypeOf(NullInt64{})
1262 case "float64":
1263 return reflect.TypeOf(float64(0))
1264 case "nullfloat64":
1265 return reflect.TypeOf(NullFloat64{})
1266 case "datetime":
1267 return reflect.TypeOf(time.Time{})
1268 case "any":
1269 return reflect.TypeOf(new(any)).Elem()
1270 }
1271 panic("invalid fakedb column type of " + typ)
1272 }
1273
View as plain text