1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package sql
17
18 import (
19 "context"
20 "database/sql/driver"
21 "errors"
22 "fmt"
23 "io"
24 "reflect"
25 "runtime"
26 "sort"
27 "strconv"
28 "sync"
29 "sync/atomic"
30 "time"
31 )
32
// driversMu guards drivers, the registry populated by Register and keyed
// by driver name.
var (
	driversMu sync.RWMutex
	drivers   = make(map[string]driver.Driver)
)

// nowFunc returns the current time. It is a variable rather than a direct
// call to time.Now so that the clock used for connection lifetime and
// idle-time bookkeeping can be substituted (NOTE(review): presumably by
// tests — confirm).
var nowFunc = time.Now
40
41
42
43
44 func Register(name string, driver driver.Driver) {
45 driversMu.Lock()
46 defer driversMu.Unlock()
47 if driver == nil {
48 panic("sql: Register driver is nil")
49 }
50 if _, dup := drivers[name]; dup {
51 panic("sql: Register called twice for driver " + name)
52 }
53 drivers[name] = driver
54 }
55
56 func unregisterAllDrivers() {
57 driversMu.Lock()
58 defer driversMu.Unlock()
59
60 drivers = make(map[string]driver.Driver)
61 }
62
63
64 func Drivers() []string {
65 driversMu.RLock()
66 defer driversMu.RUnlock()
67 list := make([]string, 0, len(drivers))
68 for name := range drivers {
69 list = append(list, name)
70 }
71 sort.Strings(list)
72 return list
73 }
74
75
76
77
78
79
80
// NamedArg pairs a parameter name with its value. Named is a shorthand
// constructor. NOTE(review): how NamedArg values are bound to placeholders
// is implemented outside this chunk (driverArgsConnLocked) — confirm there.
type NamedArg struct {
	// _NamedFieldsRequired is an unexported, zero-size field. Because it
	// cannot be set from outside the package, unkeyed (positional) literals
	// of NamedArg do not compile for external callers, who must write
	// NamedArg{Name: ..., Value: ...}.
	_NamedFieldsRequired struct{}

	// Name is the name of the parameter placeholder.
	Name string

	// Value is the value bound to the parameter.
	Value any
}
97
98
99
100
101
102
103
104
105
106
107
108
109
110 func Named(name string, value any) NamedArg {
111
112
113
114
115 return NamedArg{Name: name, Value: value}
116 }
117
118
// IsolationLevel is the transaction isolation level carried in TxOptions.
type IsolationLevel int

// Isolation levels that may be requested through TxOptions.Isolation.
// NOTE(review): which levels are honored depends on the driver — confirm.
const (
	LevelDefault IsolationLevel = iota
	LevelReadUncommitted
	LevelReadCommitted
	LevelWriteCommitted
	LevelRepeatableRead
	LevelSnapshot
	LevelSerializable
	LevelLinearizable
)

// isolationLevelNames holds the display name of each defined
// IsolationLevel, indexed by the level's numeric value.
var isolationLevelNames = [...]string{
	LevelDefault:         "Default",
	LevelReadUncommitted: "Read Uncommitted",
	LevelReadCommitted:   "Read Committed",
	LevelWriteCommitted:  "Write Committed",
	LevelRepeatableRead:  "Repeatable Read",
	LevelSnapshot:        "Snapshot",
	LevelSerializable:    "Serializable",
	LevelLinearizable:    "Linearizable",
}

// String returns the human-readable name of the isolation level, or
// "IsolationLevel(N)" for a value that is not one of the defined levels.
func (i IsolationLevel) String() string {
	if i >= 0 && int(i) < len(isolationLevelNames) {
		return isolationLevelNames[i]
	}
	return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
}

// Compile-time check that IsolationLevel satisfies fmt.Stringer.
var _ fmt.Stringer = LevelDefault
161
162
// TxOptions holds the transaction options passed to BeginTx. BeginTx hands
// them unchanged, via begin and beginDC, to the driver's begin call
// (ctxDriverBegin, not shown).
type TxOptions struct {
	// Isolation is the requested transaction isolation level.
	// NOTE(review): interpretation of LevelDefault is up to ctxDriverBegin
	// and the driver — confirm.
	Isolation IsolationLevel
	// ReadOnly requests a read-only transaction.
	ReadOnly bool
}
169
170
171
172
// RawBytes is a named []byte type. NOTE(review): its scanning semantics
// live outside this chunk. Upstream documents that it may reference memory
// owned by the driver, valid only until the next call to Next, Scan or
// Close — confirm before relying on that.
type RawBytes []byte
174
175
176
177
178
179
180
181
182
183
184
185
186
187 type NullString struct {
188 String string
189 Valid bool
190 }
191
192
193 func (ns *NullString) Scan(value any) error {
194 if value == nil {
195 ns.String, ns.Valid = "", false
196 return nil
197 }
198 ns.Valid = true
199 return convertAssign(&ns.String, value)
200 }
201
202
203 func (ns NullString) Value() (driver.Value, error) {
204 if !ns.Valid {
205 return nil, nil
206 }
207 return ns.String, nil
208 }
209
210
211
212
213 type NullInt64 struct {
214 Int64 int64
215 Valid bool
216 }
217
218
219 func (n *NullInt64) Scan(value any) error {
220 if value == nil {
221 n.Int64, n.Valid = 0, false
222 return nil
223 }
224 n.Valid = true
225 return convertAssign(&n.Int64, value)
226 }
227
228
229 func (n NullInt64) Value() (driver.Value, error) {
230 if !n.Valid {
231 return nil, nil
232 }
233 return n.Int64, nil
234 }
235
236
237
238
239 type NullInt32 struct {
240 Int32 int32
241 Valid bool
242 }
243
244
245 func (n *NullInt32) Scan(value any) error {
246 if value == nil {
247 n.Int32, n.Valid = 0, false
248 return nil
249 }
250 n.Valid = true
251 return convertAssign(&n.Int32, value)
252 }
253
254
255 func (n NullInt32) Value() (driver.Value, error) {
256 if !n.Valid {
257 return nil, nil
258 }
259 return int64(n.Int32), nil
260 }
261
262
263
264
265 type NullInt16 struct {
266 Int16 int16
267 Valid bool
268 }
269
270
271 func (n *NullInt16) Scan(value any) error {
272 if value == nil {
273 n.Int16, n.Valid = 0, false
274 return nil
275 }
276 err := convertAssign(&n.Int16, value)
277 n.Valid = err == nil
278 return err
279 }
280
281
282 func (n NullInt16) Value() (driver.Value, error) {
283 if !n.Valid {
284 return nil, nil
285 }
286 return int64(n.Int16), nil
287 }
288
289
290
291
292 type NullByte struct {
293 Byte byte
294 Valid bool
295 }
296
297
298 func (n *NullByte) Scan(value any) error {
299 if value == nil {
300 n.Byte, n.Valid = 0, false
301 return nil
302 }
303 err := convertAssign(&n.Byte, value)
304 n.Valid = err == nil
305 return err
306 }
307
308
309 func (n NullByte) Value() (driver.Value, error) {
310 if !n.Valid {
311 return nil, nil
312 }
313 return int64(n.Byte), nil
314 }
315
316
317
318
319 type NullFloat64 struct {
320 Float64 float64
321 Valid bool
322 }
323
324
325 func (n *NullFloat64) Scan(value any) error {
326 if value == nil {
327 n.Float64, n.Valid = 0, false
328 return nil
329 }
330 n.Valid = true
331 return convertAssign(&n.Float64, value)
332 }
333
334
335 func (n NullFloat64) Value() (driver.Value, error) {
336 if !n.Valid {
337 return nil, nil
338 }
339 return n.Float64, nil
340 }
341
342
343
344
345 type NullBool struct {
346 Bool bool
347 Valid bool
348 }
349
350
351 func (n *NullBool) Scan(value any) error {
352 if value == nil {
353 n.Bool, n.Valid = false, false
354 return nil
355 }
356 n.Valid = true
357 return convertAssign(&n.Bool, value)
358 }
359
360
361 func (n NullBool) Value() (driver.Value, error) {
362 if !n.Valid {
363 return nil, nil
364 }
365 return n.Bool, nil
366 }
367
368
369
370
371 type NullTime struct {
372 Time time.Time
373 Valid bool
374 }
375
376
377 func (n *NullTime) Scan(value any) error {
378 if value == nil {
379 n.Time, n.Valid = time.Time{}, false
380 return nil
381 }
382 n.Valid = true
383 return convertAssign(&n.Time, value)
384 }
385
386
387 func (n NullTime) Value() (driver.Value, error) {
388 if !n.Valid {
389 return nil, nil
390 }
391 return n.Time, nil
392 }
393
394
// Scanner is implemented by types that can assign themselves from a value
// produced by a database driver. The Null* types in this file implement it.
//
// A nil src conventionally means SQL NULL. That is how the Null* types here
// treat it: they reset their value and clear Valid.
// NOTE(review): the full set of src types that may be passed is determined
// by the driver and the scanning code outside this chunk — confirm there.
type Scanner interface {
	// Scan assigns src to the receiver and reports an error if the value
	// cannot be stored without loss of information.
	Scan(src any) error
}
416
417
418
419
420
421
422
423
424
// Out wraps a destination for a parameter whose value is written back by
// the database. NOTE(review): upstream uses Out for stored-procedure OUTPUT
// parameters (and INOUT parameters when In is true). The handling lives
// outside this chunk — confirm there.
type Out struct {
	// _NamedFieldsRequired is an unexported zero-size field that forces
	// external callers to use keyed composite literals.
	_NamedFieldsRequired struct{}

	// Dest is the destination for the returned value (presumably a
	// pointer — confirm against the argument-conversion code).
	Dest any

	// In marks the parameter as also being an input (INOUT).
	In bool
}
437
438
439
440
// ErrNoRows is the sentinel error reporting that a query produced no rows.
// NOTE(review): upstream returns it from Row.Scan when QueryRow matched
// nothing; that code is outside this chunk. Compare with errors.Is.
var ErrNoRows = errors.New("sql: no rows in result set")
442
443
444
445
446
447
448
449
450
451
452
453
454
// DB is a database handle representing a pool of zero or more underlying
// driver connections. It is created by Open or OpenDB. The mutable pool
// state is guarded by mu, and the counters read without mu are atomics.
// NOTE(review): upstream documents DB as safe for concurrent use — confirm
// against the full file.
type DB struct {
	// waitDuration is the total time callers spent blocked in conn waiting
	// for a connection, stored as int64 nanoseconds (time.Duration).
	// Atomic; read by Stats without mu.
	waitDuration atomic.Int64

	// connector creates new driver connections (used by conn and
	// openNewConnection) and is closed by Close if it implements io.Closer.
	connector driver.Connector

	// numClosed counts connections whose finalClose has completed. Atomic;
	// snapshotted into Stmt.lastNumClosed by prepareDC.
	numClosed atomic.Uint64

	mu sync.Mutex // protects the fields below

	// freeConn holds idle connections. conn takes from the end and
	// putConnDBLocked appends, so the front holds the longest-idle ones.
	freeConn []*driverConn
	// connRequests holds callers blocked in conn waiting for a connection,
	// keyed by values from nextRequestKeyLocked.
	connRequests map[uint64]chan connRequest
	nextRequest  uint64 // next key to use in connRequests
	// numOpen counts open connections plus dials in progress (it is
	// incremented optimistically before Connect).
	numOpen int

	// openerCh receives one value per connection maybeOpenNewConnections
	// wants opened. The connectionOpener goroutine reads it and exits when
	// the context cancelled by stop is done.
	openerCh chan struct{}
	closed   bool
	// dep tracks, for each finalCloser, the dependents still using it (see
	// addDepLocked/removeDepLocked).
	dep map[finalCloser]depSet
	// lastPut records the stack of the last putConn per connection; it is
	// only written when debugGetPut is true.
	lastPut map[*driverConn]string
	// maxIdleCount: zero means defaultMaxIdleConns, negative means no idle
	// connections (see maxIdleConnsLocked).
	maxIdleCount int
	maxOpen      int           // <= 0 means unlimited
	maxLifetime  time.Duration // maximum age of a connection; <= 0 means no limit
	maxIdleTime  time.Duration // maximum idle time of a connection; <= 0 means no limit
	// cleanerCh wakes connectionCleaner. It is non-nil while the cleaner
	// runs and is closed by Close.
	cleanerCh         chan struct{}
	waitCount         int64 // total number of connections waited for
	maxIdleClosed     int64 // connections closed because the idle pool was full
	maxIdleTimeClosed int64 // connections closed for exceeding maxIdleTime
	maxLifetimeClosed int64 // connections closed for exceeding maxLifetime

	// stop cancels the context passed to the connectionOpener goroutine.
	stop func()
}
491
492
// connReuseStrategy tells (*DB).conn whether it may hand out an idle pooled
// connection.
type connReuseStrategy uint8

const (
	// alwaysNewConn skips the idle pool. conn dials a new connection unless
	// maxOpen has been reached, in which case it waits for one to be handed
	// over by putConnDBLocked.
	alwaysNewConn connReuseStrategy = iota
	// cachedOrNewConn reuses an idle connection if one is available.
	// Otherwise it waits (when maxOpen is reached) or dials a new one.
	cachedOrNewConn
)
503
504
505
506
507
// driverConn wraps a driver.Conn together with a mutex. Calls into ci are
// made with that mutex held (see the withLock(dc, ...) calls in this file).
type driverConn struct {
	db        *DB
	createdAt time.Time // when the connection was opened; checked by expired

	sync.Mutex  // locks the connection; the fields below are mutated under it
	ci          driver.Conn
	needReset   bool // the session should be reset (resetSession) before reuse
	closed      bool // Close or closeDBLocked has run
	finalClosed bool // ci.Close has been called by finalClose
	// openStmt tracks statements prepared by prepareLocked with a nil
	// grabber so that finalClose can close them.
	openStmt map[*driverStmt]bool

	// The fields below are read and written with db.mu held.
	inUse      bool
	returnedAt time.Time // time the connection was created or last returned
	onPut      []func()  // run by putConn, with db.mu held, on the next return
	dbmuClosed bool      // set by Close under db.mu; mirrors closed
}
525
526 func (dc *driverConn) releaseConn(err error) {
527 dc.db.putConn(dc, err, true)
528 }
529
530 func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
531 dc.Lock()
532 defer dc.Unlock()
533 delete(dc.openStmt, ds)
534 }
535
536 func (dc *driverConn) expired(timeout time.Duration) bool {
537 if timeout <= 0 {
538 return false
539 }
540 return dc.createdAt.Add(timeout).Before(nowFunc())
541 }
542
543
544
545 func (dc *driverConn) resetSession(ctx context.Context) error {
546 dc.Lock()
547 defer dc.Unlock()
548
549 if !dc.needReset {
550 return nil
551 }
552 if cr, ok := dc.ci.(driver.SessionResetter); ok {
553 return cr.ResetSession(ctx)
554 }
555 return nil
556 }
557
558
559
560 func (dc *driverConn) validateConnection(needsReset bool) bool {
561 dc.Lock()
562 defer dc.Unlock()
563
564 if needsReset {
565 dc.needReset = true
566 }
567 if cv, ok := dc.ci.(driver.Validator); ok {
568 return cv.IsValid()
569 }
570 return true
571 }
572
573
574
575 func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
576 si, err := ctxDriverPrepare(ctx, dc.ci, query)
577 if err != nil {
578 return nil, err
579 }
580 ds := &driverStmt{Locker: dc, si: si}
581
582
583 if cg != nil {
584 return ds, nil
585 }
586
587
588
589
590
591 if dc.openStmt == nil {
592 dc.openStmt = make(map[*driverStmt]bool)
593 }
594 dc.openStmt[ds] = true
595 return ds, nil
596 }
597
598
599 func (dc *driverConn) closeDBLocked() func() error {
600 dc.Lock()
601 defer dc.Unlock()
602 if dc.closed {
603 return func() error { return errors.New("sql: duplicate driverConn close") }
604 }
605 dc.closed = true
606 return dc.db.removeDepLocked(dc, dc)
607 }
608
609 func (dc *driverConn) Close() error {
610 dc.Lock()
611 if dc.closed {
612 dc.Unlock()
613 return errors.New("sql: duplicate driverConn close")
614 }
615 dc.closed = true
616 dc.Unlock()
617
618
619 dc.db.mu.Lock()
620 dc.dbmuClosed = true
621 fn := dc.db.removeDepLocked(dc, dc)
622 dc.db.mu.Unlock()
623 return fn()
624 }
625
// finalClose actually tears down the connection. It is the function that
// removeDepLocked returns once dc has no remaining dependents.
//
// It closes every statement still tracked in openStmt, then closes ci, then
// decrements db.numOpen (letting maybeOpenNewConnections use the freed slot
// for waiters), and finally bumps db.numClosed. It returns the error from
// ci.Close.
func (dc *driverConn) finalClose() error {
	var err error

	// Each *driverStmt uses dc as its Locker (see prepareLocked), and
	// ds.Close takes that lock. Copy the set out under the lock and close
	// the statements after releasing it; closing them inside withLock would
	// self-deadlock.
	var openStmt []*driverStmt
	withLock(dc, func() {
		openStmt = make([]*driverStmt, 0, len(dc.openStmt))
		for ds := range dc.openStmt {
			openStmt = append(openStmt, ds)
		}
		dc.openStmt = nil
	})
	for _, ds := range openStmt {
		ds.Close()
	}
	withLock(dc, func() {
		dc.finalClosed = true
		err = dc.ci.Close()
		dc.ci = nil
	})

	dc.db.mu.Lock()
	dc.db.numOpen--
	dc.db.maybeOpenNewConnections()
	dc.db.mu.Unlock()

	dc.db.numClosed.Add(1)
	return err
}
656
657
658
659
// driverStmt pairs a driver.Stmt with the Locker that must be held while
// using it. The Locker is the owning driverConn (see prepareLocked).
type driverStmt struct {
	sync.Locker
	si       driver.Stmt
	closed   bool
	closeErr error // the error returned by the single si.Close call
}

// Close closes the underlying driver statement exactly once. Later calls
// return the error that the first close produced.
func (ds *driverStmt) Close() error {
	ds.Lock()
	defer ds.Unlock()
	if !ds.closed {
		ds.closed = true
		ds.closeErr = ds.si.Close()
	}
	return ds.closeErr
}
679
680
// depSet is the set of dependents currently holding on to a finalCloser.
type depSet map[any]bool // set of true bools

// finalCloser is implemented by values whose teardown must wait until all
// of their dependents are gone. removeDepLocked hands back finalClose once
// the last dependency is removed.
type finalCloser interface {
	// finalClose is called when the reference count of an object
	// goes to zero. (*DB).mu is not held while calling it.
	finalClose() error
}
690
691
692
// addDep records that dep depends on x, so that x's finalClose is delayed
// until every such dependency is removed. It acquires db.mu itself; use
// addDepLocked when the lock is already held.
func (db *DB) addDep(x finalCloser, dep any) {
	db.mu.Lock()
	// defer keeps the lock released even if the map insert panics (e.g. on
	// an unhashable dep).
	defer db.mu.Unlock()
	db.addDepLocked(x, dep)
}
698
699 func (db *DB) addDepLocked(x finalCloser, dep any) {
700 if db.dep == nil {
701 db.dep = make(map[finalCloser]depSet)
702 }
703 xdep := db.dep[x]
704 if xdep == nil {
705 xdep = make(depSet)
706 db.dep[x] = xdep
707 }
708 xdep[dep] = true
709 }
710
711
712
713
714
715 func (db *DB) removeDep(x finalCloser, dep any) error {
716 db.mu.Lock()
717 fn := db.removeDepLocked(x, dep)
718 db.mu.Unlock()
719 return fn()
720 }
721
722 func (db *DB) removeDepLocked(x finalCloser, dep any) func() error {
723 xdep, ok := db.dep[x]
724 if !ok {
725 panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
726 }
727
728 l0 := len(xdep)
729 delete(xdep, dep)
730
731 switch len(xdep) {
732 case l0:
733
734 panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
735 case 0:
736
737 delete(db.dep, x)
738 return x.finalClose
739 default:
740
741 return func() error { return nil }
742 }
743 }
744
745
746
747
748
749
// connectionRequestQueueSize is the buffer size of DB.openerCh (set in
// OpenDB). maybeOpenNewConnections sends on that channel while holding
// db.mu, so the large buffer keeps those sends from blocking in practice.
var connectionRequestQueueSize = 1000000
751
// dsnConnector adapts a plain driver.Driver to driver.Connector. Each
// Connect opens the stored data source name through the driver.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

// Connect opens a new connection by passing the stored DSN to the driver's
// Open method. The context is ignored because Open does not accept one.
func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
	drv, name := t.driver, t.dsn
	return drv.Open(name)
}

// Driver returns the wrapped driver.
func (t dsnConnector) Driver() driver.Driver {
	drv := t.driver
	return drv
}
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781 func OpenDB(c driver.Connector) *DB {
782 ctx, cancel := context.WithCancel(context.Background())
783 db := &DB{
784 connector: c,
785 openerCh: make(chan struct{}, connectionRequestQueueSize),
786 lastPut: make(map[*driverConn]string),
787 connRequests: make(map[uint64]chan connRequest),
788 stop: cancel,
789 }
790
791 go db.connectionOpener(ctx)
792
793 return db
794 }
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813 func Open(driverName, dataSourceName string) (*DB, error) {
814 driversMu.RLock()
815 driveri, ok := drivers[driverName]
816 driversMu.RUnlock()
817 if !ok {
818 return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
819 }
820
821 if driverCtx, ok := driveri.(driver.DriverContext); ok {
822 connector, err := driverCtx.OpenConnector(dataSourceName)
823 if err != nil {
824 return nil, err
825 }
826 return OpenDB(connector), nil
827 }
828
829 return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
830 }
831
832 func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
833 var err error
834 if pinger, ok := dc.ci.(driver.Pinger); ok {
835 withLock(dc, func() {
836 err = pinger.Ping(ctx)
837 })
838 }
839 release(err)
840 return err
841 }
842
843
844
845 func (db *DB) PingContext(ctx context.Context) error {
846 var dc *driverConn
847 var err error
848
849 err = db.retry(func(strategy connReuseStrategy) error {
850 dc, err = db.conn(ctx, strategy)
851 return err
852 })
853
854 if err != nil {
855 return err
856 }
857
858 return db.pingDC(ctx, dc, dc.releaseConn)
859 }
860
861
862
863
864
865
866 func (db *DB) Ping() error {
867 return db.PingContext(context.Background())
868 }
869
870
871
872
873
874
875
// Close closes the database and prevents new queries from starting: once
// closed is set, conn returns errDBClosed. Calling Close again returns nil.
//
// Idle connections are closed, and every pending connection request is
// unblocked by closing its channel. The connectionOpener goroutine is
// stopped via db.stop, and the connector is closed if it implements
// io.Closer. Connections currently in use are not closed here. They are
// closed when returned, because putConnDBLocked rejects them once the DB is
// closed. If several steps fail, the last error is returned.
func (db *DB) Close() error {
	db.mu.Lock()
	if db.closed { // Make DB.Close idempotent
		db.mu.Unlock()
		return nil
	}
	if db.cleanerCh != nil {
		// Wakes connectionCleaner. Once it gets db.mu it sees db.closed
		// and exits.
		close(db.cleanerCh)
	}
	var err error
	fns := make([]func() error, 0, len(db.freeConn))
	for _, dc := range db.freeConn {
		fns = append(fns, dc.closeDBLocked())
	}
	db.freeConn = nil
	db.closed = true
	for _, req := range db.connRequests {
		close(req)
	}
	db.mu.Unlock()
	// Run the deferred closes without db.mu: finalClose acquires it itself.
	for _, fn := range fns {
		err1 := fn()
		if err1 != nil {
			err = err1
		}
	}
	db.stop()
	if c, ok := db.connector.(io.Closer); ok {
		err1 := c.Close()
		if err1 != nil {
			err = err1
		}
	}
	return err
}
911
912 const defaultMaxIdleConns = 2
913
914 func (db *DB) maxIdleConnsLocked() int {
915 n := db.maxIdleCount
916 switch {
917 case n == 0:
918
919 return defaultMaxIdleConns
920 case n < 0:
921 return 0
922 default:
923 return n
924 }
925 }
926
927 func (db *DB) shortestIdleTimeLocked() time.Duration {
928 if db.maxIdleTime <= 0 {
929 return db.maxLifetime
930 }
931 if db.maxLifetime <= 0 {
932 return db.maxIdleTime
933 }
934
935 min := db.maxIdleTime
936 if min > db.maxLifetime {
937 min = db.maxLifetime
938 }
939 return min
940 }
941
942
943
944
945
946
947
948
949
950
951
952 func (db *DB) SetMaxIdleConns(n int) {
953 db.mu.Lock()
954 if n > 0 {
955 db.maxIdleCount = n
956 } else {
957
958 db.maxIdleCount = -1
959 }
960
961 if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
962 db.maxIdleCount = db.maxOpen
963 }
964 var closing []*driverConn
965 idleCount := len(db.freeConn)
966 maxIdle := db.maxIdleConnsLocked()
967 if idleCount > maxIdle {
968 closing = db.freeConn[maxIdle:]
969 db.freeConn = db.freeConn[:maxIdle]
970 }
971 db.maxIdleClosed += int64(len(closing))
972 db.mu.Unlock()
973 for _, c := range closing {
974 c.Close()
975 }
976 }
977
978
979
980
981
982
983
984
985
986 func (db *DB) SetMaxOpenConns(n int) {
987 db.mu.Lock()
988 db.maxOpen = n
989 if n < 0 {
990 db.maxOpen = 0
991 }
992 syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
993 db.mu.Unlock()
994 if syncMaxIdle {
995 db.SetMaxIdleConns(n)
996 }
997 }
998
999
1000
1001
1002
1003
1004 func (db *DB) SetConnMaxLifetime(d time.Duration) {
1005 if d < 0 {
1006 d = 0
1007 }
1008 db.mu.Lock()
1009
1010 if d > 0 && d < db.maxLifetime && db.cleanerCh != nil {
1011 select {
1012 case db.cleanerCh <- struct{}{}:
1013 default:
1014 }
1015 }
1016 db.maxLifetime = d
1017 db.startCleanerLocked()
1018 db.mu.Unlock()
1019 }
1020
1021
1022
1023
1024
1025
1026 func (db *DB) SetConnMaxIdleTime(d time.Duration) {
1027 if d < 0 {
1028 d = 0
1029 }
1030 db.mu.Lock()
1031 defer db.mu.Unlock()
1032
1033
1034 if d > 0 && d < db.maxIdleTime && db.cleanerCh != nil {
1035 select {
1036 case db.cleanerCh <- struct{}{}:
1037 default:
1038 }
1039 }
1040 db.maxIdleTime = d
1041 db.startCleanerLocked()
1042 }
1043
1044
1045 func (db *DB) startCleanerLocked() {
1046 if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil {
1047 db.cleanerCh = make(chan struct{}, 1)
1048 go db.connectionCleaner(db.shortestIdleTimeLocked())
1049 }
1050 }
1051
// connectionCleaner runs in its own goroutine, started by
// startCleanerLocked. It periodically closes idle connections that have
// exceeded maxIdleTime or maxLifetime.
//
// It wakes when its timer fires or when cleanerCh receives a value or is
// closed. It exits, resetting cleanerCh to nil, once the DB is closed, no
// connections are open, or neither limit is set. The sleep between runs is
// never shorter than one second.
func (db *DB) connectionCleaner(d time.Duration) {
	const minInterval = time.Second

	if d < minInterval {
		d = minInterval
	}
	t := time.NewTimer(d)

	for {
		select {
		case <-t.C:
		case <-db.cleanerCh: // maxLifetime/maxIdleTime was shortened or db was closed.
		}

		db.mu.Lock()

		d = db.shortestIdleTimeLocked()
		if db.closed || db.numOpen == 0 || d <= 0 {
			db.cleanerCh = nil
			db.mu.Unlock()
			return
		}

		// Note: this := declares a new d scoped to the loop body; it holds
		// the next-run delay computed from the current idle pool.
		d, closing := db.connectionCleanerRunLocked(d)
		db.mu.Unlock()
		for _, c := range closing {
			c.Close()
		}

		if d < minInterval {
			d = minInterval
		}

		// Stop the timer and drain a pending tick (if any) before Reset.
		if !t.Stop() {
			select {
			case <-t.C:
			default:
			}
		}
		t.Reset(d)
	}
}
1094
1095
1096
1097
// connectionCleanerRunLocked removes from freeConn every connection that
// has been idle longer than maxIdleTime or is older than maxLifetime, and
// updates maxIdleTimeClosed and maxLifetimeClosed. It returns the delay
// until the next check (d, possibly shortened so the next expiry is not
// missed) and the removed connections. The caller must hold db.mu and must
// close the returned connections after releasing it.
func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) {
	var idleClosing int64
	var closing []*driverConn
	if db.maxIdleTime > 0 {
		// Connections are appended to freeConn as they are returned, so the
		// slice is assumed to be ordered by returnedAt, oldest first. The
		// first match found scanning from the end therefore marks the
		// boundary: everything at or before it is idle-expired.
		idleSince := nowFunc().Add(-db.maxIdleTime)
		last := len(db.freeConn) - 1
		for i := last; i >= 0; i-- {
			c := db.freeConn[i]
			if c.returnedAt.Before(idleSince) {
				i++
				// Cap the capacity so the append in the lifetime pass below
				// reallocates instead of overwriting freeConn's entries.
				closing = db.freeConn[:i:i]
				db.freeConn = db.freeConn[i:]
				idleClosing = int64(len(closing))
				db.maxIdleTimeClosed += idleClosing
				break
			}
		}

		if len(db.freeConn) > 0 {
			c := db.freeConn[0]
			if d2 := c.returnedAt.Sub(idleSince); d2 < d {
				// Wake up again when the oldest remaining idle
				// connection reaches maxIdleTime.
				d = d2
			}
		}
	}

	if db.maxLifetime > 0 {
		expiredSince := nowFunc().Add(-db.maxLifetime)
		for i := 0; i < len(db.freeConn); i++ {
			c := db.freeConn[i]
			if c.createdAt.Before(expiredSince) {
				closing = append(closing, c)

				last := len(db.freeConn) - 1
				// Order-preserving delete: keeps freeConn sorted by
				// returnedAt, which the idle pass above relies on.
				copy(db.freeConn[i:], db.freeConn[i+1:])
				db.freeConn[last] = nil
				db.freeConn = db.freeConn[:last]
				i--
			} else if d2 := c.createdAt.Sub(expiredSince); d2 < d {
				// Wake up again when this connection reaches maxLifetime.
				d = d2
			}
		}
		// closing also contains the idle-expired connections; count only
		// those added by this pass.
		db.maxLifetimeClosed += int64(len(closing)) - idleClosing
	}

	return d, closing
}
1153
1154
// DBStats is a snapshot of connection-pool statistics returned by DB.Stats.
type DBStats struct {
	MaxOpenConnections int // configured maxOpen; 0 means unlimited

	// Pool status.
	OpenConnections int // db.numOpen: open connections plus dials in progress
	InUse           int // OpenConnections minus Idle
	Idle            int // connections in the idle pool

	// Counters.
	WaitCount         int64         // total number of times a caller had to wait for a connection
	WaitDuration      time.Duration // total time spent waiting for a connection
	MaxIdleClosed     int64         // connections closed because the idle pool was full
	MaxIdleTimeClosed int64         // connections closed for exceeding maxIdleTime
	MaxLifetimeClosed int64         // connections closed for exceeding maxLifetime
}
1170
1171
1172 func (db *DB) Stats() DBStats {
1173 wait := db.waitDuration.Load()
1174
1175 db.mu.Lock()
1176 defer db.mu.Unlock()
1177
1178 stats := DBStats{
1179 MaxOpenConnections: db.maxOpen,
1180
1181 Idle: len(db.freeConn),
1182 OpenConnections: db.numOpen,
1183 InUse: db.numOpen - len(db.freeConn),
1184
1185 WaitCount: db.waitCount,
1186 WaitDuration: time.Duration(wait),
1187 MaxIdleClosed: db.maxIdleClosed,
1188 MaxIdleTimeClosed: db.maxIdleTimeClosed,
1189 MaxLifetimeClosed: db.maxLifetimeClosed,
1190 }
1191 return stats
1192 }
1193
1194
1195
1196
1197 func (db *DB) maybeOpenNewConnections() {
1198 numRequests := len(db.connRequests)
1199 if db.maxOpen > 0 {
1200 numCanOpen := db.maxOpen - db.numOpen
1201 if numRequests > numCanOpen {
1202 numRequests = numCanOpen
1203 }
1204 }
1205 for numRequests > 0 {
1206 db.numOpen++
1207 numRequests--
1208 if db.closed {
1209 return
1210 }
1211 db.openerCh <- struct{}{}
1212 }
1213 }
1214
1215
1216 func (db *DB) connectionOpener(ctx context.Context) {
1217 for {
1218 select {
1219 case <-ctx.Done():
1220 return
1221 case <-db.openerCh:
1222 db.openNewConnection(ctx)
1223 }
1224 }
1225 }
1226
1227
// openNewConnection dials one connection on behalf of
// maybeOpenNewConnections.
//
// The dial happens without db.mu held. Afterwards, if the DB has closed,
// the connection is discarded. A dial error is forwarded to one waiting
// request (if any) and another open is attempted. A successful connection
// is offered to putConnDBLocked and closed if not accepted.
func (db *DB) openNewConnection(ctx context.Context) {
	// maybeOpenNewConnections has already executed db.numOpen++ before it sent
	// on db.openerCh. This function must execute db.numOpen-- if the
	// connection fails or is closed before returning.
	ci, err := db.connector.Connect(ctx)
	db.mu.Lock()
	defer db.mu.Unlock()
	if db.closed {
		if err == nil {
			ci.Close()
		}
		db.numOpen--
		return
	}
	if err != nil {
		db.numOpen--
		db.putConnDBLocked(nil, err)
		db.maybeOpenNewConnections()
		return
	}
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
	}
	if db.putConnDBLocked(dc, err) {
		db.addDepLocked(dc, dc)
	} else {
		db.numOpen--
		ci.Close()
	}
}
1261
1262
1263
1264
// connRequest is what a waiter blocked in (*DB).conn receives: either a
// connection already marked in use, or the error from a failed dial (with
// conn nil).
type connRequest struct {
	conn *driverConn
	err  error
}

// errDBClosed is returned by conn once the DB has been closed, including to
// waiters whose request channel Close closed.
var errDBClosed = errors.New("sql: database is closed")
1271
1272
1273
1274 func (db *DB) nextRequestKeyLocked() uint64 {
1275 next := db.nextRequest
1276 db.nextRequest++
1277 return next
1278 }
1279
1280
// conn returns a driverConn reserved for the caller (inUse is set).
//
// With cachedOrNewConn, conn first takes the most recently returned idle
// connection. If that connection is expired, it is closed and
// driver.ErrBadConn is returned so the caller's retry loop tries again.
// If maxOpen has been reached, the caller registers a connRequest and
// blocks until putConnDBLocked hands over a connection, ctx is done, or the
// DB is closed. Otherwise a new connection is dialed directly.
//
// Errors from resetSession other than driver.ErrBadConn are ignored, and
// the connection is still returned.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}
	// Check if the context is expired.
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	lifetime := db.maxLifetime

	// Prefer a free connection, if possible.
	last := len(db.freeConn) - 1
	if strategy == cachedOrNewConn && last >= 0 {
		// Take from the end (the most recently returned connection), which
		// leaves the longest-idle ones at the front where the idle-time
		// cleaner looks for them.
		conn := db.freeConn[last]
		db.freeConn = db.freeConn[:last]
		conn.inUse = true
		if conn.expired(lifetime) {
			db.maxLifetimeClosed++
			db.mu.Unlock()
			conn.Close()
			return nil, driver.ErrBadConn
		}
		db.mu.Unlock()

		// Reset the session if required.
		if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
			conn.Close()
			return nil, err
		}

		return conn, nil
	}

	// Out of free connections or we were asked not to use one. If we're not
	// allowed to open any more connections, make a request and wait.
	if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
		// Buffered (capacity 1) so the single send in putConnDBLocked,
		// which runs with db.mu held, never blocks on this goroutine.
		req := make(chan connRequest, 1)
		reqKey := db.nextRequestKeyLocked()
		db.connRequests[reqKey] = req
		db.waitCount++
		db.mu.Unlock()

		waitStart := nowFunc()

		// Timeout the connection request with the context.
		select {
		case <-ctx.Done():
			// Remove the connection request and ensure no value has been sent
			// on it after removing.
			db.mu.Lock()
			delete(db.connRequests, reqKey)
			db.mu.Unlock()

			db.waitDuration.Add(int64(time.Since(waitStart)))

			// A connection may have been delivered before the request was
			// removed; put it back instead of leaking it.
			select {
			default:
			case ret, ok := <-req:
				if ok && ret.conn != nil {
					db.putConn(ret.conn, ret.err, false)
				}
			}
			return nil, ctx.Err()
		case ret, ok := <-req:
			db.waitDuration.Add(int64(time.Since(waitStart)))

			if !ok {
				// The channel was closed by DB.Close.
				return nil, errDBClosed
			}
			// Expiry is only checked for cachedOrNewConn. With
			// alwaysNewConn the handed-over connection is used as-is; an
			// expired one is caught when it comes back through putConn.
			if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
				db.mu.Lock()
				db.maxLifetimeClosed++
				db.mu.Unlock()
				ret.conn.Close()
				return nil, driver.ErrBadConn
			}
			if ret.conn == nil {
				return nil, ret.err
			}

			// Reset the session if required.
			if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
				ret.conn.Close()
				return nil, err
			}
			return ret.conn, ret.err
		}
	}

	db.numOpen++ // optimistically
	db.mu.Unlock()
	ci, err := db.connector.Connect(ctx)
	if err != nil {
		db.mu.Lock()
		db.numOpen-- // correct for earlier optimism
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		return nil, err
	}
	db.mu.Lock()
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
		inUse:      true,
	}
	db.addDepLocked(dc, dc)
	db.mu.Unlock()
	return dc, nil
}
1407
1408
// putConnHook, when non-nil, is called by putConn (with db.mu held) just
// before a healthy connection is offered back to the pool.
// NOTE(review): presumably a test hook — confirm.
var putConnHook func(*DB, *driverConn)
1410
1411
1412
1413
1414 func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
1415 db.mu.Lock()
1416 defer db.mu.Unlock()
1417 if c.inUse {
1418 c.onPut = append(c.onPut, func() {
1419 ds.Close()
1420 })
1421 } else {
1422 c.Lock()
1423 fc := c.finalClosed
1424 c.Unlock()
1425 if !fc {
1426 ds.Close()
1427 }
1428 }
1429 }
1430
1431
1432
// debugGetPut enables recording of the call stack of each putConn in
// db.lastPut, and printing both stacks when a duplicate return is
// detected. It is compiled out while false.
const debugGetPut = false
1434
1435
1436
// putConn returns dc to the pool after use. err is the error, if any, from
// the last operation on dc. resetSession flags the session for reset before
// the connection is reused (see validateConnection and resetSession).
//
// dc is discarded (closed) in any of these cases: err is driver.ErrBadConn,
// the driver's Validator reports it invalid, or it has outlived
// maxLifetime. Otherwise it is handed to a waiting request or kept idle by
// putConnDBLocked, and closed if neither accepts it. Callbacks queued in
// dc.onPut run here with db.mu held. It panics if dc is not currently in
// use.
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
	if !errors.Is(err, driver.ErrBadConn) {
		if !dc.validateConnection(resetSession) {
			err = driver.ErrBadConn
		}
	}
	db.mu.Lock()
	if !dc.inUse {
		db.mu.Unlock()
		if debugGetPut {
			fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
		}
		panic("sql: connection returned that was never out")
	}

	if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
		db.maxLifetimeClosed++
		err = driver.ErrBadConn
	}
	if debugGetPut {
		db.lastPut[dc] = stack()
	}
	dc.inUse = false
	dc.returnedAt = nowFunc()

	for _, fn := range dc.onPut {
		fn()
	}
	dc.onPut = nil

	if errors.Is(err, driver.ErrBadConn) {
		// Don't reuse bad connections.
		// Since the conn is considered bad and is being discarded, treat it
		// as closed. Don't decrement the open count here, finalClose will
		// take care of that.
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		dc.Close()
		return
	}
	if putConnHook != nil {
		putConnHook(db, dc)
	}
	added := db.putConnDBLocked(dc, nil)
	db.mu.Unlock()

	if !added {
		dc.Close()
		return
	}
}
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
// putConnDBLocked offers dc (or, when dc is nil, the dial error err) to the
// pool. The caller must hold db.mu.
//
// If any connRequests are pending, one of them receives connRequest{dc,
// err}; Go map iteration order makes the choice arbitrary. Otherwise, when
// err is nil, dc joins the idle pool if there is room. The result reports
// whether dc was taken. On false, the caller is responsible for closing dc,
// as putConn and openNewConnection do. dc is always rejected once the DB is
// closed, or while more connections are open than maxOpen allows (e.g.
// after the limit was lowered).
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
	if db.closed {
		return false
	}
	if db.maxOpen > 0 && db.numOpen > db.maxOpen {
		return false
	}
	if c := len(db.connRequests); c > 0 {
		var req chan connRequest
		var reqKey uint64
		for reqKey, req = range db.connRequests {
			break
		}
		delete(db.connRequests, reqKey) // Remove from pending requests.
		if err == nil {
			dc.inUse = true
		}
		// Never blocks: each request channel has capacity 1 and receives
		// at most one value, since it was just removed from the map.
		req <- connRequest{
			conn: dc,
			err:  err,
		}
		return true
	} else if err == nil && !db.closed { // db.closed was already checked above; this test is redundant
		if db.maxIdleConnsLocked() > len(db.freeConn) {
			db.freeConn = append(db.freeConn, dc)
			db.startCleanerLocked()
			return true
		}
		db.maxIdleClosed++
	}
	return false
}
1530
1531
1532
1533
1534 const maxBadConnRetries = 2
1535
1536 func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
1537 for i := int64(0); i < maxBadConnRetries; i++ {
1538 err := fn(cachedOrNewConn)
1539
1540 if err == nil || !errors.Is(err, driver.ErrBadConn) {
1541 return err
1542 }
1543 }
1544
1545 return fn(alwaysNewConn)
1546 }
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
1557 var stmt *Stmt
1558 var err error
1559
1560 err = db.retry(func(strategy connReuseStrategy) error {
1561 stmt, err = db.prepare(ctx, query, strategy)
1562 return err
1563 })
1564
1565 return stmt, err
1566 }
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576 func (db *DB) Prepare(query string) (*Stmt, error) {
1577 return db.PrepareContext(context.Background(), query)
1578 }
1579
1580 func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
1581
1582
1583
1584
1585
1586
1587 dc, err := db.conn(ctx, strategy)
1588 if err != nil {
1589 return nil, err
1590 }
1591 return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
1592 }
1593
1594
1595
1596
// prepareDC prepares query on dc and wraps it in a Stmt. release is always
// called, via the deferred closure, with the final value of err.
//
// When cg is nil, the Stmt records the (dc, ds) pair in css, snapshots
// db.numClosed, and registers itself as its own dependency through addDep.
// NOTE(review): that dependency is presumably removed by Stmt.Close
// (not shown). When cg is non-nil, the driver statement is only stored in
// cgds.
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
	var ds *driverStmt
	var err error
	defer func() {
		release(err)
	}()
	withLock(dc, func() {
		ds, err = dc.prepareLocked(ctx, cg, query)
	})
	if err != nil {
		return nil, err
	}
	stmt := &Stmt{
		db:    db,
		query: query,
		cg:    cg,
		cgds:  ds,
	}

	// A statement without a grabber is not tied to one connection: record
	// where it was prepared and keep it registered with the DB.
	if cg == nil {
		stmt.css = []connStmt{{dc, ds}}
		stmt.lastNumClosed = db.numClosed.Load()
		db.addDep(stmt, stmt)
	}
	return stmt, nil
}
1626
1627
1628
1629 func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
1630 var res Result
1631 var err error
1632
1633 err = db.retry(func(strategy connReuseStrategy) error {
1634 res, err = db.exec(ctx, query, args, strategy)
1635 return err
1636 })
1637
1638 return res, err
1639 }
1640
1641
1642
1643
1644
1645
1646 func (db *DB) Exec(query string, args ...any) (Result, error) {
1647 return db.ExecContext(context.Background(), query, args...)
1648 }
1649
1650 func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
1651 dc, err := db.conn(ctx, strategy)
1652 if err != nil {
1653 return nil, err
1654 }
1655 return db.execDC(ctx, dc, dc.releaseConn, query, args)
1656 }
1657
// execDC executes query on dc.
//
// If the driver connection implements ExecerContext or Execer, the query is
// executed directly. If that path returns driver.ErrSkip, or neither
// interface is implemented, the query is prepared, executed, and the
// temporary statement closed. release is always called, via the deferred
// closure over the named result err, with the final error.
func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) {
	defer func() {
		release(err)
	}()
	execerCtx, ok := dc.ci.(driver.ExecerContext)
	var execer driver.Execer
	if !ok {
		execer, ok = dc.ci.(driver.Execer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var resi driver.Result
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
		})
		// ErrSkip means "fall back to prepare/exec" rather than failure.
		if err != driver.ErrSkip {
			if err != nil {
				return nil, err
			}
			return driverResult{dc, resi}, nil
		}
	}

	var si driver.Stmt
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		return nil, err
	}
	ds := &driverStmt{Locker: dc, si: si}
	defer ds.Close()
	return resultFromStatement(ctx, dc.ci, ds, args...)
}
1696
1697
1698
1699 func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
1700 var rows *Rows
1701 var err error
1702
1703 err = db.retry(func(strategy connReuseStrategy) error {
1704 rows, err = db.query(ctx, query, args, strategy)
1705 return err
1706 })
1707
1708 return rows, err
1709 }
1710
1711
1712
1713
1714
1715
1716 func (db *DB) Query(query string, args ...any) (*Rows, error) {
1717 return db.QueryContext(context.Background(), query, args...)
1718 }
1719
1720 func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
1721 dc, err := db.conn(ctx, strategy)
1722 if err != nil {
1723 return nil, err
1724 }
1725
1726 return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
1727 }
1728
1729
1730
1731
1732
// queryDC runs query on dc and returns the resulting Rows.
//
// The driver's QueryerContext or Queryer is tried first. On driver.ErrSkip,
// or when neither is implemented, the query is prepared and run through a
// temporary statement, which the Rows will close (closeStmt). On error,
// releaseConn is called before returning. On success, the Rows takes over
// dc and releaseConn. txctx is passed to initContextClose
// (NOTE(review): presumably the enclosing transaction's context, nil
// outside a Tx — confirm in Rows).
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
	queryerCtx, ok := dc.ci.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = dc.ci.(driver.Queryer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var rowsi driver.Rows
		var err error
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
		})
		// ErrSkip means "fall back to prepare/query" rather than failure.
		if err != driver.ErrSkip {
			if err != nil {
				releaseConn(err)
				return nil, err
			}
			// Note: ownership of dc passes to the rows, and will be
			// freed with releaseConn.
			rows := &Rows{
				dc:          dc,
				releaseConn: releaseConn,
				rowsi:       rowsi,
			}
			rows.initContextClose(ctx, txctx)
			return rows, nil
		}
	}

	var si driver.Stmt
	var err error
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		releaseConn(err)
		return nil, err
	}

	ds := &driverStmt{Locker: dc, si: si}
	rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
	if err != nil {
		ds.Close()
		releaseConn(err)
		return nil, err
	}

	// Note: ownership of ci passes to the *Rows, to be freed
	// with releaseConn. The temporary statement is closed via closeStmt.
	rows := &Rows{
		dc:          dc,
		releaseConn: releaseConn,
		rowsi:       rowsi,
		closeStmt:   ds,
	}
	rows.initContextClose(ctx, txctx)
	return rows, nil
}
1796
1797
1798
1799
1800
1801
1802
1803 func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
1804 rows, err := db.QueryContext(ctx, query, args...)
1805 return &Row{rows: rows, err: err}
1806 }
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817 func (db *DB) QueryRow(query string, args ...any) *Row {
1818 return db.QueryRowContext(context.Background(), query, args...)
1819 }
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831 func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
1832 var tx *Tx
1833 var err error
1834
1835 err = db.retry(func(strategy connReuseStrategy) error {
1836 tx, err = db.begin(ctx, opts, strategy)
1837 return err
1838 })
1839
1840 return tx, err
1841 }
1842
1843
1844
1845
1846
1847
1848 func (db *DB) Begin() (*Tx, error) {
1849 return db.BeginTx(context.Background(), nil)
1850 }
1851
1852 func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
1853 dc, err := db.conn(ctx, strategy)
1854 if err != nil {
1855 return nil, err
1856 }
1857 return db.beginDC(ctx, dc, dc.releaseConn, opts)
1858 }
1859
1860
// beginDC starts a driver transaction on the already-acquired connection dc
// and wraps it in a Tx. On failure, release is called with the error and
// (nil, err) is returned. On success, the Tx owns dc and keeps release as
// its releaseConn.
//
// The Tx receives a context derived from ctx via WithCancel. A goroutine
// (Tx.awaitDone) waits on that context and attempts a rollback once it is
// done. The context ends either when ctx ends or when Commit/rollback call
// tx.cancel.
func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
	var txi driver.Tx
	keepConnOnRollback := false
	withLock(dc, func() {
		// After a context-triggered rollback, awaitDone keeps the connection
		// only when the driver can both reset its session and validate
		// itself.
		_, hasSessionResetter := dc.ci.(driver.SessionResetter)
		_, hasConnectionValidator := dc.ci.(driver.Validator)
		keepConnOnRollback = hasSessionResetter && hasConnectionValidator
		txi, err = ctxDriverBegin(ctx, opts, dc.ci)
	})
	if err != nil {
		release(err)
		return nil, err
	}

	// Schedule the transaction to roll back when the context is canceled.
	// Commit and rollback mark the Tx done before calling cancel, so
	// awaitDone's rollback is then a no-op.
	ctx, cancel := context.WithCancel(ctx)
	tx = &Tx{
		db:                 db,
		dc:                 dc,
		releaseConn:        release,
		txi:                txi,
		cancel:             cancel,
		keepConnOnRollback: keepConnOnRollback,
		ctx:                ctx,
	}
	go tx.awaitDone()
	return tx, nil
}
1890
1891
1892 func (db *DB) Driver() driver.Driver {
1893 return db.connector.Driver()
1894 }
1895
1896
1897
// ErrConnDone is returned by operations on a Conn that has already been
// closed. It is also returned by a second call to Conn.Close.
var ErrConnDone = errors.New("sql: connection is already closed")
1899
1900
1901
1902
1903
1904
1905
1906
1907 func (db *DB) Conn(ctx context.Context) (*Conn, error) {
1908 var dc *driverConn
1909 var err error
1910
1911 err = db.retry(func(strategy connReuseStrategy) error {
1912 dc, err = db.conn(ctx, strategy)
1913 return err
1914 })
1915
1916 if err != nil {
1917 return nil, err
1918 }
1919
1920 conn := &Conn{
1921 db: db,
1922 dc: dc,
1923 }
1924 return conn, nil
1925 }
1926
// releaseConn is the callback used to hand a driver connection back after an
// operation. Its error argument lets the implementation react to failures;
// for example, Conn's implementation closes the Conn on driver.ErrBadConn.
type releaseConn func(error)
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
// Conn is a single driver connection obtained from the pool by DB.Conn and
// held until Close. All of its operations run on that one connection.
type Conn struct {
	db *DB

	// closemu is read-locked for the duration of each operation that uses
	// dc: grabConn takes the read lock and the releaseConn callback drops it.
	// close takes the write lock, so it waits for in-flight operations.
	closemu sync.RWMutex

	// dc is the underlying driver connection. close sets it to nil.
	dc *driverConn

	// done is set (first, before closemu is write-locked) by close.
	done atomic.Bool

	// releaseConnOnce lazily stores the method value
	// closemuRUnlockCondReleaseConn in releaseConnCache so it is built only
	// once and then reused by every grabConn call.
	releaseConnOnce  sync.Once
	releaseConnCache releaseConn
}
1959
1960
1961
// grabConn implements stmtConnGrabber. It returns the Conn's driver
// connection and a release callback that the caller must invoke when
// finished; the callback drops the closemu read lock taken here. The context
// argument is unused.
//
// done is checked before closemu is read-locked. Because close sets done
// before it requests the write lock, new operations fail fast with
// ErrConnDone once a close has started, instead of queuing behind the
// pending writer.
// NOTE(review): if close both starts and finishes between the Load and the
// RLock, this returns c.dc after close has set it to nil. Confirm that
// callers tolerate this or that the window cannot occur.
func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
	if c.done.Load() {
		return nil, nil, ErrConnDone
	}
	c.releaseConnOnce.Do(func() {
		c.releaseConnCache = c.closemuRUnlockCondReleaseConn
	})
	c.closemu.RLock()
	return c.dc, c.releaseConnCache, nil
}
1972
1973
1974 func (c *Conn) PingContext(ctx context.Context) error {
1975 dc, release, err := c.grabConn(ctx)
1976 if err != nil {
1977 return err
1978 }
1979 return c.db.pingDC(ctx, dc, release)
1980 }
1981
1982
1983
1984 func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
1985 dc, release, err := c.grabConn(ctx)
1986 if err != nil {
1987 return nil, err
1988 }
1989 return c.db.execDC(ctx, dc, release, query, args)
1990 }
1991
1992
1993
1994 func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
1995 dc, release, err := c.grabConn(ctx)
1996 if err != nil {
1997 return nil, err
1998 }
1999 return c.db.queryDC(ctx, nil, dc, release, query, args)
2000 }
2001
2002
2003
2004
2005
2006
2007
2008 func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
2009 rows, err := c.QueryContext(ctx, query, args...)
2010 return &Row{rows: rows, err: err}
2011 }
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021 func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
2022 dc, release, err := c.grabConn(ctx)
2023 if err != nil {
2024 return nil, err
2025 }
2026 return c.db.prepareDC(ctx, dc, release, c, query)
2027 }
2028
2029
2030
2031
2032
2033
// Raw calls f with the underlying driver connection (dc.ci) while holding
// the driverConn's mutex, and returns f's error. The mutex is released when
// f returns, so presumably the value passed to f should not be retained
// after that.
//
// If f panics, the deferred handler still unlocks the mutex and releases the
// connection with driver.ErrBadConn. That makes the Conn close itself (see
// closemuRUnlockCondReleaseConn). The panic is not recovered and keeps
// propagating.
func (c *Conn) Raw(f func(driverConn any) error) (err error) {
	var dc *driverConn
	var release releaseConn

	// grabConn takes a context only to satisfy stmtConnGrabber and ignores
	// it, so passing nil is safe here.
	dc, release, err = c.grabConn(nil)
	if err != nil {
		return
	}
	fPanic := true
	dc.Mutex.Lock()
	defer func() {
		dc.Mutex.Unlock()

		// fPanic is still true only if f panicked. Report the connection
		// as bad so it is not reused in an unknown state.
		if fPanic {
			err = driver.ErrBadConn
		}
		release(err)
	}()
	err = f(dc.ci)
	fPanic = false

	return
}
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072 func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
2073 dc, release, err := c.grabConn(ctx)
2074 if err != nil {
2075 return nil, err
2076 }
2077 return c.db.beginDC(ctx, dc, release, opts)
2078 }
2079
2080
2081
2082 func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
2083 c.closemu.RUnlock()
2084 if errors.Is(err, driver.ErrBadConn) {
2085 c.close(err)
2086 }
2087 }
2088
2089 func (c *Conn) txCtx() context.Context {
2090 return nil
2091 }
2092
// close marks the Conn done, waits for in-flight operations to drop their
// closemu read locks, and then releases the driver connection with err. It
// returns err, or ErrConnDone if the Conn was already closed.
func (c *Conn) close(err error) error {
	if !c.done.CompareAndSwap(false, true) {
		return ErrConnDone
	}

	// Hold the write lock while releasing the driver connection, so every
	// operation that holds closemu.RLock (from grabConn) has finished first.
	c.closemu.Lock()
	defer c.closemu.Unlock()

	c.dc.releaseConn(err)
	c.dc = nil
	c.db = nil
	return err
}
2108
2109
2110
2111
2112
2113
2114 func (c *Conn) Close() error {
2115 return c.close(nil)
2116 }
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
// Tx is an in-progress transaction bound to one driver connection. It is
// created by beginDC and finished by exactly one successful Commit or
// rollback; later attempts return ErrTxDone.
type Tx struct {
	db *DB

	// closemu is read-locked by each operation that uses dc (grabConn) and
	// write-locked by Commit/rollback, which therefore wait for in-flight
	// operations before finishing the transaction.
	closemu sync.RWMutex

	// dc and txi are owned by the transaction until close, which releases
	// dc and sets both to nil.
	dc  *driverConn
	txi driver.Tx

	// releaseConn is called from close to give dc back.
	releaseConn func(error)

	// done is flipped from false to true (CompareAndSwap) by Commit or
	// rollback, so only one of them proceeds.
	done atomic.Bool

	// keepConnOnRollback is true when the driver implements both
	// SessionResetter and Validator. In that case awaitDone rolls back
	// without forcing driver.ErrBadConn.
	keepConnOnRollback bool

	// stmts holds statements prepared in or attached to this transaction.
	// closePrepared closes them when the transaction ends, unless the driver
	// reported driver.ErrBadConn.
	stmts struct {
		sync.Mutex
		v []*Stmt
	}

	// cancel cancels ctx. Commit and rollback call it.
	cancel func()

	// ctx is derived from the BeginTx context. It is done when that context
	// ends or when cancel is called.
	ctx context.Context
}
2168
2169
2170
2171 func (tx *Tx) awaitDone() {
2172
2173
2174 <-tx.ctx.Done()
2175
2176
2177
2178
2179
2180
2181
2182 discardConnection := !tx.keepConnOnRollback
2183 tx.rollback(discardConnection)
2184 }
2185
2186 func (tx *Tx) isDone() bool {
2187 return tx.done.Load()
2188 }
2189
2190
2191
// ErrTxDone is returned by operations on a Tx that has already been
// committed or rolled back.
var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
2193
2194
2195
2196
// close releases the transaction's connection with err and then clears the
// driver references. Its callers (Commit and rollback) have already marked
// the Tx done and waited on closemu's write lock.
func (tx *Tx) close(err error) {
	tx.releaseConn(err)
	tx.dc = nil
	tx.txi = nil
}
2202
2203
2204
// hookTxGrabConn, when non-nil, is called by Tx.grabConn while the closemu
// read lock is held (presumably a test hook).
var hookTxGrabConn func()
2206
// grabConn implements stmtConnGrabber. It returns ctx.Err() if ctx is already
// done, or ErrTxDone if the transaction has finished. Otherwise it returns
// the transaction's connection and a release callback that drops the
// closemu read lock taken here.
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
	// Non-blocking check of ctx.
	select {
	default:
	case <-ctx.Done():
		return nil, nil, ctx.Err()
	}

	// closemu.RLock must come before the isDone check. Commit/rollback take
	// the write lock, so the Tx cannot finish while this operation runs.
	tx.closemu.RLock()
	if tx.isDone() {
		tx.closemu.RUnlock()
		return nil, nil, ErrTxDone
	}
	if hookTxGrabConn != nil {
		hookTxGrabConn()
	}
	return tx.dc, tx.closemuRUnlockRelease, nil
}
2226
2227 func (tx *Tx) txCtx() context.Context {
2228 return tx.ctx
2229 }
2230
2231
2232
2233
2234
2235 func (tx *Tx) closemuRUnlockRelease(error) {
2236 tx.closemu.RUnlock()
2237 }
2238
2239
2240 func (tx *Tx) closePrepared() {
2241 tx.stmts.Lock()
2242 defer tx.stmts.Unlock()
2243 for _, stmt := range tx.stmts.v {
2244 stmt.Close()
2245 }
2246 }
2247
2248
// Commit commits the transaction. It returns ErrTxDone if the transaction
// already finished, or the context's error if the transaction's context is
// already done. In the latter case awaitDone performs the rollback.
func (tx *Tx) Commit() error {
	// Check the context before flipping done. Once done is set, the Tx is
	// committed (or failed committing), so a canceled context must be
	// reported before that point to keep done consistent with reality.
	select {
	default:
	case <-tx.ctx.Done():
		if tx.done.Load() {
			return ErrTxDone
		}
		return tx.ctx.Err()
	}

	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	// Cancel the Tx context so Rows bound to it close and drop their closemu
	// read locks. This is safe because done is already true, which makes
	// awaitDone's rollback a no-op. Taking and immediately releasing the
	// write lock waits for every in-flight operation.
	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Commit()
	})
	// Prepared statements are left alone when the driver reported
	// driver.ErrBadConn.
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	tx.close(err)
	return err
}
2283
// rollbackHook, when non-nil, is called by Tx.rollback right after the Tx is
// marked done (presumably a test hook).
var rollbackHook func()
2285
2286
2287
// rollback aborts the transaction. It returns ErrTxDone if the transaction
// already finished. When discardConn is true, the connection is released
// with driver.ErrBadConn, and that is also the returned error, whatever the
// driver's Rollback returned.
func (tx *Tx) rollback(discardConn bool) error {
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	if rollbackHook != nil {
		rollbackHook()
	}

	// Cancel the Tx context so Rows bound to it close and drop their closemu
	// read locks. This is safe because done is already true. Taking and
	// immediately releasing the write lock waits for every in-flight
	// operation.
	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Rollback()
	})
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	if discardConn {
		err = driver.ErrBadConn
	}
	tx.close(err)
	return err
}
2318
2319
2320 func (tx *Tx) Rollback() error {
2321 return tx.rollback(false)
2322 }
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334 func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
2335 dc, release, err := tx.grabConn(ctx)
2336 if err != nil {
2337 return nil, err
2338 }
2339
2340 stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
2341 if err != nil {
2342 return nil, err
2343 }
2344 tx.stmts.Lock()
2345 tx.stmts.v = append(tx.stmts.v, stmt)
2346 tx.stmts.Unlock()
2347 return stmt, nil
2348 }
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359 func (tx *Tx) Prepare(query string) (*Stmt, error) {
2360 return tx.PrepareContext(context.Background(), query)
2361 }
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
// StmtContext returns a transaction-specific version of stmt that runs on
// the transaction's connection. Failures are not returned directly; they are
// stored as the sticky error of the returned Stmt.
//
// If stmt is still open and not owned by a Tx/Conn, the driver statement
// already prepared on this connection is reused (or prepared and recorded in
// stmt.css), and stmt becomes the new statement's parentStmt. Otherwise the
// query is re-prepared on the connection.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
	dc, release, err := tx.grabConn(ctx)
	if err != nil {
		return &Stmt{stickyErr: err}
	}
	defer release(nil)

	if tx.db != stmt.db {
		return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
	}
	var si driver.Stmt
	var parentStmt *Stmt
	stmt.mu.Lock()
	if stmt.closed || stmt.cg != nil {
		// A closed statement, or one already owned by a Tx/Conn, cannot
		// share its driver statement here, so re-prepare the query on this
		// connection.
		stmt.mu.Unlock()
		withLock(dc, func() {
			si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
		})
		if err != nil {
			return &Stmt{stickyErr: err}
		}
	} else {
		stmt.removeClosedStmtLocked()
		// Reuse the driver statement already prepared on this connection,
		// if any.
		for _, v := range stmt.css {
			if v.dc == dc {
				si = v.ds.si
				break
			}
		}

		stmt.mu.Unlock()

		if si == nil {
			var ds *driverStmt
			withLock(dc, func() {
				ds, err = stmt.prepareOnConnLocked(ctx, dc)
			})
			if err != nil {
				return &Stmt{stickyErr: err}
			}
			si = ds.si
		}
		parentStmt = stmt
	}

	txs := &Stmt{
		db: tx.db,
		cg: tx,
		cgds: &driverStmt{
			Locker: dc,
			si:     si,
		},
		parentStmt: parentStmt,
		query:      stmt.query,
	}
	// Track the dependency so the parent's driver statement outlives txs.
	if parentStmt != nil {
		tx.db.addDep(parentStmt, txs)
	}
	tx.stmts.Lock()
	tx.stmts.v = append(tx.stmts.v, txs)
	tx.stmts.Unlock()
	return txs
}
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467 func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
2468 return tx.StmtContext(context.Background(), stmt)
2469 }
2470
2471
2472
2473 func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
2474 dc, release, err := tx.grabConn(ctx)
2475 if err != nil {
2476 return nil, err
2477 }
2478 return tx.db.execDC(ctx, dc, release, query, args)
2479 }
2480
2481
2482
2483
2484
2485
2486 func (tx *Tx) Exec(query string, args ...any) (Result, error) {
2487 return tx.ExecContext(context.Background(), query, args...)
2488 }
2489
2490
2491 func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
2492 dc, release, err := tx.grabConn(ctx)
2493 if err != nil {
2494 return nil, err
2495 }
2496
2497 return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
2498 }
2499
2500
2501
2502
2503
2504 func (tx *Tx) Query(query string, args ...any) (*Rows, error) {
2505 return tx.QueryContext(context.Background(), query, args...)
2506 }
2507
2508
2509
2510
2511
2512
2513
2514 func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
2515 rows, err := tx.QueryContext(ctx, query, args...)
2516 return &Row{rows: rows, err: err}
2517 }
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528 func (tx *Tx) QueryRow(query string, args ...any) *Row {
2529 return tx.QueryRowContext(context.Background(), query, args...)
2530 }
2531
2532
// connStmt pairs a driver connection with a statement prepared on it.
type connStmt struct {
	dc *driverConn
	ds *driverStmt
}
2537
2538
2539
// stmtConnGrabber is implemented by Tx and Conn. These are the owners of a
// Stmt that must always run on one specific connection.
type stmtConnGrabber interface {
	// grabConn returns the owner's driver connection and a release callback
	// that the caller must invoke when done with it.
	grabConn(context.Context) (*driverConn, releaseConn, error)

	// txCtx returns the transaction context, or nil when the owner is not a
	// transaction (Conn returns nil). Stmt.QueryContext passes it to the
	// Rows so they close when the transaction ends.
	txCtx() context.Context
}
2550
2551 var (
2552 _ stmtConnGrabber = &Tx{}
2553 _ stmtConnGrabber = &Conn{}
2554 )
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
// Stmt is a prepared statement. A Stmt created from a DB (cg == nil) is
// prepared lazily on each pooled connection it runs on, with the results
// cached in css. A Stmt owned by a Tx or Conn (cg != nil) always runs on the
// owner's connection using cgds.
type Stmt struct {
	// Set at construction.
	db        *DB
	query     string
	stickyErr error // if non-nil, returned by connStmt and Close

	// closemu is read-locked by ExecContext/QueryContext and write-locked by
	// Close, so Close waits for running executions.
	closemu sync.RWMutex

	// cg is the owning Tx or Conn, or nil for a DB-level statement. When cg
	// is set, connStmt always grabs the connection from cg and uses cgds.
	// When cg is nil, a connection comes from db and css tells whether the
	// statement must be prepared on it again.
	cg   stmtConnGrabber
	cgds *driverStmt

	// parentStmt is set by Tx.StmtContext when this statement reuses a
	// driver statement owned by (stored in the css of) another Stmt. Close
	// then removes the dependency on the parent instead of closing cgds.
	parentStmt *Stmt

	mu     sync.Mutex // guards closed, css, and lastNumClosed
	closed bool

	// css holds the driver statements prepared on specific pooled
	// connections. It is only used when cg == nil.
	css []connStmt

	// lastNumClosed is the value of db.numClosed at the last sweep by
	// removeClosedStmtLocked.
	lastNumClosed uint64
}
2602
2603
2604
2605 func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
2606 s.closemu.RLock()
2607 defer s.closemu.RUnlock()
2608
2609 var res Result
2610 err := s.db.retry(func(strategy connReuseStrategy) error {
2611 dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
2612 if err != nil {
2613 return err
2614 }
2615
2616 res, err = resultFromStatement(ctx, dc.ci, ds, args...)
2617 releaseConn(err)
2618 return err
2619 })
2620
2621 return res, err
2622 }
2623
2624
2625
2626
2627
2628
2629 func (s *Stmt) Exec(args ...any) (Result, error) {
2630 return s.ExecContext(context.Background(), args...)
2631 }
2632
2633 func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) {
2634 ds.Lock()
2635 defer ds.Unlock()
2636
2637 dargs, err := driverArgsConnLocked(ci, ds, args)
2638 if err != nil {
2639 return nil, err
2640 }
2641
2642 resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
2643 if err != nil {
2644 return nil, err
2645 }
2646 return driverResult{ds.Locker, resi}, nil
2647 }
2648
2649
2650
2651
2652
2653 func (s *Stmt) removeClosedStmtLocked() {
2654 t := len(s.css)/2 + 1
2655 if t > 10 {
2656 t = 10
2657 }
2658 dbClosed := s.db.numClosed.Load()
2659 if dbClosed-s.lastNumClosed < uint64(t) {
2660 return
2661 }
2662
2663 s.db.mu.Lock()
2664 for i := 0; i < len(s.css); i++ {
2665 if s.css[i].dc.dbmuClosed {
2666 s.css[i] = s.css[len(s.css)-1]
2667 s.css = s.css[:len(s.css)-1]
2668 i--
2669 }
2670 }
2671 s.db.mu.Unlock()
2672 s.lastNumClosed = dbClosed
2673 }
2674
2675
2676
2677
// connStmt returns a driver connection, its release callback, and a driver
// statement for s that is valid on that connection.
//
// A Stmt owned by a Tx or Conn (s.cg != nil) always uses the owner's
// connection and cgds. Otherwise connStmt takes a pooled connection using
// strategy. It reuses a statement already prepared on that connection (from
// s.css) or prepares a new one, which prepareOnConnLocked records in s.css.
// If preparation fails, the connection is released with the error.
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
	if err = s.stickyErr; err != nil {
		return
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		err = errors.New("sql: statement is closed")
		return
	}

	// In a transaction or connection, always use the connection the
	// statement was created on.
	if s.cg != nil {
		s.mu.Unlock()
		dc, releaseConn, err = s.cg.grabConn(ctx)
		if err != nil {
			return
		}
		return dc, releaseConn, s.cgds, nil
	}

	s.removeClosedStmtLocked()
	s.mu.Unlock()

	dc, err = s.db.conn(ctx, strategy)
	if err != nil {
		return nil, nil, nil, err
	}

	// Reuse a statement already prepared on this connection, if any.
	s.mu.Lock()
	for _, v := range s.css {
		if v.dc == dc {
			s.mu.Unlock()
			return dc, dc.releaseConn, v.ds, nil
		}
	}
	s.mu.Unlock()

	// Not prepared here yet; prepare it on this connection.
	withLock(dc, func() {
		ds, err = s.prepareOnConnLocked(ctx, dc)
	})
	if err != nil {
		dc.releaseConn(err)
		return nil, nil, nil, err
	}

	return dc, dc.releaseConn, ds, nil
}
2728
2729
2730
2731 func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
2732 si, err := dc.prepareLocked(ctx, s.cg, s.query)
2733 if err != nil {
2734 return nil, err
2735 }
2736 cs := connStmt{dc, si}
2737 s.mu.Lock()
2738 s.css = append(s.css, cs)
2739 s.mu.Unlock()
2740 return cs.ds, nil
2741 }
2742
2743
2744
// QueryContext executes the prepared statement with args and returns the
// resulting Rows. Connection acquisition is retried through db.retry. The
// Rows holds a dependency on s (addDep) until it is closed, and it closes
// automatically when ctx or the owner's transaction context is done.
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
	s.closemu.RLock()
	defer s.closemu.RUnlock()

	var rowsi driver.Rows
	var rows *Rows

	err := s.db.retry(func(strategy connReuseStrategy) error {
		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
		if err != nil {
			return err
		}

		rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
		if err == nil {
			// Ownership of dc passes to the Rows; it is freed through
			// rows.releaseConn.
			rows = &Rows{
				dc:    dc,
				rowsi: rowsi,
				// releaseConn set below
			}
			// addDep must happen before initContextClose; otherwise the
			// awaitDone goroutine could call removeDep before addDep.
			s.db.addDep(s, rows)

			// releaseConn must be set before initContextClose; otherwise
			// awaitDone could close the Rows before it is set.
			rows.releaseConn = func(err error) {
				releaseConn(err)
				s.db.removeDep(s, rows)
			}
			var txctx context.Context
			if s.cg != nil {
				txctx = s.cg.txCtx()
			}
			rows.initContextClose(ctx, txctx)
			return nil
		}

		releaseConn(err)
		return err
	})

	return rows, err
}
2791
2792
2793
2794
2795
2796
2797 func (s *Stmt) Query(args ...any) (*Rows, error) {
2798 return s.QueryContext(context.Background(), args...)
2799 }
2800
2801 func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) {
2802 ds.Lock()
2803 defer ds.Unlock()
2804 dargs, err := driverArgsConnLocked(ci, ds, args)
2805 if err != nil {
2806 return nil, err
2807 }
2808 return ctxDriverStmtQuery(ctx, ds.si, dargs)
2809 }
2810
2811
2812
2813
2814
2815
2816
2817 func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row {
2818 rows, err := s.QueryContext(ctx, args...)
2819 if err != nil {
2820 return &Row{err: err}
2821 }
2822 return &Row{rows: rows}
2823 }
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839 func (s *Stmt) QueryRow(args ...any) *Row {
2840 return s.QueryRowContext(context.Background(), args...)
2841 }
2842
2843
// Close closes the statement. It waits for running Exec/Query calls (closemu
// write lock), returns the sticky error if the Stmt was created with one,
// and is a no-op on an already-closed Stmt.
func (s *Stmt) Close() error {
	s.closemu.Lock()
	defer s.closemu.Unlock()

	if s.stickyErr != nil {
		return s.stickyErr
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return nil
	}
	s.closed = true
	txds := s.cgds
	s.cgds = nil

	s.mu.Unlock()

	// DB-level statement: drop its self-dependency.
	if s.cg == nil {
		return s.db.removeDep(s, s)
	}

	if s.parentStmt != nil {
		// cgds is shared with the parent (it lives in the parent's css), so
		// it must not be closed here; only drop the dependency.
		return s.db.removeDep(s.parentStmt, s)
	}
	return txds.Close()
}
2873
2874 func (s *Stmt) finalClose() error {
2875 s.mu.Lock()
2876 defer s.mu.Unlock()
2877 if s.css != nil {
2878 for _, v := range s.css {
2879 s.db.noteUnusedDriverStatement(v.dc, v.ds)
2880 v.dc.removeOpenStmt(v.ds)
2881 }
2882 s.css = nil
2883 }
2884 return nil
2885 }
2886
2887
2888
// Rows is the result of a query. Next advances it row by row; Scan reads the
// current row.
type Rows struct {
	dc          *driverConn // owned by the Rows; released via releaseConn in close
	releaseConn func(error)
	rowsi       driver.Rows
	cancel      func()      // set by initContextClose; called in close; may be nil
	closeStmt   *driverStmt // if non-nil, closed by close

	// contextDone records the error from ctx/txctx when awaitDone observes
	// that one of them ended. Next stops on it and Err reports it.
	contextDone atomic.Pointer[error]

	// closemu is read-locked by Next, NextResultSet, Scan, Columns,
	// ColumnTypes, and Err, and write-locked by close. It guards closed and
	// lasterr.
	closemu sync.RWMutex
	closed  bool
	lasterr error

	// lastcols holds the current row's values as filled in by nextLocked.
	lastcols []driver.Value

	// closemuScanHold is set when Scan returned while still holding the
	// closemu read lock (dest contained a *RawBytes). Next, NextResultSet,
	// and Close release that lock, which prevents close (including
	// awaitDone's) from running in the meantime. Presumably this keeps the
	// driver memory referenced by RawBytes valid.
	closemuScanHold bool

	// hitEOF is set by Next when Next itself closed the Rows and returned
	// false. That happens at the end of data or on a driver error. Err then
	// ignores contextDone.
	hitEOF bool
}
2926
2927
2928
2929 func (rs *Rows) lasterrOrErrLocked(err error) error {
2930 if rs.lasterr != nil && rs.lasterr != io.EOF {
2931 return rs.lasterr
2932 }
2933 return err
2934 }
2935
2936
2937
// bypassRowsAwaitDone, when true, makes initContextClose skip starting the
// awaitDone goroutine (presumably a test switch).
var bypassRowsAwaitDone = false
2939
// initContextClose starts the awaitDone goroutine, which closes rs when ctx
// or txctx is done. It does nothing when neither context can ever be done
// (Done() returns nil) or when bypassRowsAwaitDone is set. The goroutine also
// exits once close calls rs.cancel.
func (rs *Rows) initContextClose(ctx, txctx context.Context) {
	if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
		return
	}
	if bypassRowsAwaitDone {
		return
	}
	// closectx is canceled by close (via rs.cancel), which lets awaitDone
	// return.
	closectx, cancel := context.WithCancel(ctx)
	rs.cancel = cancel
	go rs.awaitDone(ctx, txctx, closectx)
}
2951
2952
2953
2954
2955
2956
// awaitDone blocks until ctx or txctx is done, or until closectx is canceled
// by close, and then closes rs with ctx.Err(). In the first two cases the
// context's error is also recorded in contextDone.
//
// closectx is derived from ctx, so when ctx ends both cases may be ready and
// select may pick either one. Either way, close receives ctx.Err().
func (rs *Rows) awaitDone(ctx, txctx, closectx context.Context) {
	// A nil channel never becomes ready, which disables that case when
	// there is no transaction context.
	var txctxDone <-chan struct{}
	if txctx != nil {
		txctxDone = txctx.Done()
	}
	select {
	case <-ctx.Done():
		err := ctx.Err()
		rs.contextDone.Store(&err)
	case <-txctxDone:
		err := txctx.Err()
		rs.contextDone.Store(&err)
	case <-closectx.Done():
		// rs.cancel was called from close. Do not record anything in
		// contextDone, so Err is unaffected.
	}
	rs.close(ctx.Err())
}
2975
2976
2977
2978
2979
2980
2981
// Next advances to the next row so Scan can read it. It returns false when
// the current result set is exhausted, an error occurred (see Err), or the
// Rows' context was canceled. When the last result set ends or the driver
// reports an error, Next closes the Rows itself.
func (rs *Rows) Next() bool {
	// The caller is done with the previous row's Scan results (including any
	// RawBytes), so release a read lock that Scan may still hold. Otherwise
	// it could block awaitDone's close.
	rs.closemuRUnlockIfHeldByScan()

	if rs.contextDone.Load() != nil {
		return false
	}

	var doClose, ok bool
	withLock(rs.closemu.RLocker(), func() {
		doClose, ok = rs.nextLocked()
	})
	// Close takes the write lock, so it must run after the read lock above
	// has been released.
	if doClose {
		rs.Close()
	}
	if doClose && !ok {
		rs.hitEOF = true
	}
	return ok
}
3004
// nextLocked reads the next row from the driver into lastcols. ok reports
// whether a row is available. doClose reports whether the Rows should be
// closed: on a driver error, or at io.EOF when there is no further result
// set. The caller must hold closemu for reading.
func (rs *Rows) nextLocked() (doClose, ok bool) {
	if rs.closed {
		return false, false
	}

	// Lock the driver connection before calling into rowsi so a Tx cannot
	// roll the connection back concurrently.
	rs.dc.Lock()
	defer rs.dc.Unlock()

	if rs.lastcols == nil {
		rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
	}

	rs.lasterr = rs.rowsi.Next(rs.lastcols)
	if rs.lasterr != nil {
		// Any error other than io.EOF closes the Rows.
		if rs.lasterr != io.EOF {
			return true, false
		}
		nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
		if !ok {
			return true, false
		}
		// End of the current result set: close only when no further
		// result set follows.
		if !nextResultSet.HasNextResultSet() {
			doClose = true
		}
		return doClose, false
	}
	return false, true
}
3039
3040
3041
3042
3043
3044
3045
3046
3047
// NextResultSet moves to the next result set, if there is one, and reports
// whether it did. The Rows are closed when the driver does not support
// multiple result sets or when advancing fails (including running out of
// result sets). Next must then be called to read the first row of the new
// set (lastcols is reset here).
func (rs *Rows) NextResultSet() bool {
	// The caller is done with the previous row's Scan results, so release a
	// read lock that Scan may still hold.
	rs.closemuRUnlockIfHeldByScan()

	// Deferred first, so it runs after the closemu RUnlock below; Close
	// needs the write lock.
	var doClose bool
	defer func() {
		if doClose {
			rs.Close()
		}
	}()
	rs.closemu.RLock()
	defer rs.closemu.RUnlock()

	if rs.closed {
		return false
	}

	rs.lastcols = nil
	nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
	if !ok {
		doClose = true
		return false
	}

	// Lock the driver connection before calling into rowsi so a Tx cannot
	// roll the connection back concurrently.
	rs.dc.Lock()
	defer rs.dc.Unlock()

	rs.lasterr = nextResultSet.NextResultSet()
	if rs.lasterr != nil {
		doClose = true
		return false
	}
	return true
}
3086
3087
3088
// Err returns the error, if any, encountered during iteration. io.EOF is
// never returned. A context error recorded by awaitDone takes precedence,
// unless Next already closed the Rows itself (hitEOF).
func (rs *Rows) Err() error {
	// Report the context error only if Next has not already ended iteration
	// on its own; otherwise a cancellation that happens after a clean end
	// would be misreported.
	if !rs.hitEOF {
		if errp := rs.contextDone.Load(); errp != nil {
			return *errp
		}
	}

	rs.closemu.RLock()
	defer rs.closemu.RUnlock()
	return rs.lasterrOrErrLocked(nil)
}
3104
// errRowsClosed is reported (unless a real lasterr takes precedence) when
// Columns, ColumnTypes, or Scan is called on closed Rows.
var errRowsClosed = errors.New("sql: Rows are closed")

// errNoRows is reported by Columns/ColumnTypes when the Rows has no driver
// rows (rowsi == nil).
var errNoRows = errors.New("sql: no Rows available")
3107
3108
3109
3110 func (rs *Rows) Columns() ([]string, error) {
3111 rs.closemu.RLock()
3112 defer rs.closemu.RUnlock()
3113 if rs.closed {
3114 return nil, rs.lasterrOrErrLocked(errRowsClosed)
3115 }
3116 if rs.rowsi == nil {
3117 return nil, rs.lasterrOrErrLocked(errNoRows)
3118 }
3119 rs.dc.Lock()
3120 defer rs.dc.Unlock()
3121
3122 return rs.rowsi.Columns(), nil
3123 }
3124
3125
3126
3127 func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
3128 rs.closemu.RLock()
3129 defer rs.closemu.RUnlock()
3130 if rs.closed {
3131 return nil, rs.lasterrOrErrLocked(errRowsClosed)
3132 }
3133 if rs.rowsi == nil {
3134 return nil, rs.lasterrOrErrLocked(errNoRows)
3135 }
3136 rs.dc.Lock()
3137 defer rs.dc.Unlock()
3138
3139 return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
3140 }
3141
3142
// ColumnType describes one column of a result set, as reported by the
// driver. Each optional property comes with an ok flag that is false when
// the driver did not provide it.
type ColumnType struct {
	name string

	hasNullable       bool
	hasLength         bool
	hasPrecisionScale bool

	nullable     bool
	length       int64
	databaseType string
	precision    int64
	scale        int64
	scanType     reflect.Type
}

// Name returns the column name.
func (ci *ColumnType) Name() string {
	name := ci.name
	return name
}

// Length returns the driver-reported column length. ok is false when the
// driver did not report one.
func (ci *ColumnType) Length() (length int64, ok bool) {
	length, ok = ci.length, ci.hasLength
	return length, ok
}

// DecimalSize returns the driver-reported precision and scale. ok is false
// when the driver did not report them.
func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
	precision, scale, ok = ci.precision, ci.scale, ci.hasPrecisionScale
	return precision, scale, ok
}

// ScanType returns the Go type the driver suggests for scanning this column.
// When the driver does not say, rowsColumnInfoSetupConnLocked sets it to the
// empty-interface type.
func (ci *ColumnType) ScanType() reflect.Type {
	t := ci.scanType
	return t
}

// Nullable reports whether the column may be NULL. ok is false when the
// driver did not say.
func (ci *ColumnType) Nullable() (nullable, ok bool) {
	nullable, ok = ci.nullable, ci.hasNullable
	return nullable, ok
}

// DatabaseTypeName returns the driver-reported database type name, or "" if
// none was reported.
func (ci *ColumnType) DatabaseTypeName() string {
	dbType := ci.databaseType
	return dbType
}
3200
3201 func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
3202 names := rowsi.Columns()
3203
3204 list := make([]*ColumnType, len(names))
3205 for i := range list {
3206 ci := &ColumnType{
3207 name: names[i],
3208 }
3209 list[i] = ci
3210
3211 if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
3212 ci.scanType = prop.ColumnTypeScanType(i)
3213 } else {
3214 ci.scanType = reflect.TypeOf(new(any)).Elem()
3215 }
3216 if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
3217 ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
3218 }
3219 if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
3220 ci.length, ci.hasLength = prop.ColumnTypeLength(i)
3221 }
3222 if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
3223 ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
3224 }
3225 if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
3226 ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
3227 }
3228 }
3229 return list
3230 }
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
// Scan copies the columns of the current row into dest, converting each
// value with convertAssignRows. len(dest) must equal the number of columns.
//
// If any dest is a *RawBytes, Scan returns while still holding closemu for
// reading (closemuScanHold). The Rows then cannot be closed, not even by
// awaitDone, until the next Next, NextResultSet, or Close call releases the
// lock. Every error path after that point releases the lock again.
func (rs *Rows) Scan(dest ...any) error {
	if rs.closemuScanHold {
		// Only reachable if Scan is called twice in a row (with a *RawBytes
		// dest the first time) without an intervening Next.
		return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
	}
	rs.closemu.RLock()

	if rs.lasterr != nil && rs.lasterr != io.EOF {
		rs.closemu.RUnlock()
		return rs.lasterr
	}
	if rs.closed {
		err := rs.lasterrOrErrLocked(errRowsClosed)
		rs.closemu.RUnlock()
		return err
	}

	// Keep the read lock while a *RawBytes destination may alias driver
	// memory; otherwise drop it now.
	if scanArgsContainRawBytes(dest) {
		rs.closemuScanHold = true
	} else {
		rs.closemu.RUnlock()
	}

	if rs.lastcols == nil {
		rs.closemuRUnlockIfHeldByScan()
		return errors.New("sql: Scan called without calling Next")
	}
	if len(dest) != len(rs.lastcols) {
		rs.closemuRUnlockIfHeldByScan()
		return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
	}

	for i, sv := range rs.lastcols {
		err := convertAssignRows(dest[i], sv, rs)
		if err != nil {
			rs.closemuRUnlockIfHeldByScan()
			return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
		}
	}
	return nil
}
3334
3335
3336
3337 func (rs *Rows) closemuRUnlockIfHeldByScan() {
3338 if rs.closemuScanHold {
3339 rs.closemuScanHold = false
3340 rs.closemu.RUnlock()
3341 }
3342 }
3343
3344 func scanArgsContainRawBytes(args []any) bool {
3345 for _, a := range args {
3346 if _, ok := a.(*RawBytes); ok {
3347 return true
3348 }
3349 }
3350 return false
3351 }
3352
3353
3354
// rowsCloseHook returns an optional hook that Rows.close calls after the
// driver rows are closed. The hook may rewrite the close error through its
// *error argument. The default returns nil, meaning no hook (presumably
// overridden in tests).
var rowsCloseHook = func() func(*Rows, *error) { return nil }
3356
3357
3358
3359
3360
// Close closes the Rows, releasing its connection. It is safe to call more
// than once; calls after the first return nil.
func (rs *Rows) Close() error {
	// Release any read lock Scan kept for *RawBytes first. close takes the
	// write lock, and still holding the read lock here would deadlock.
	rs.closemuRUnlockIfHeldByScan()

	return rs.close(nil)
}
3369
// close closes rs once, under closemu's write lock. It records err as
// lasterr if none is set, closes the driver rows, runs rowsCloseHook,
// cancels the awaitDone goroutine, closes an owned statement, and releases
// the connection. It returns the error from the driver's Rows.Close
// (possibly adjusted by the hook), or nil if rs was already closed.
func (rs *Rows) close(err error) error {
	rs.closemu.Lock()
	defer rs.closemu.Unlock()

	if rs.closed {
		return nil
	}
	rs.closed = true

	if rs.lasterr == nil {
		rs.lasterr = err
	}

	// From here on err holds the driver's close error, not the argument.
	withLock(rs.dc, func() {
		err = rs.rowsi.Close()
	})
	if fn := rowsCloseHook(); fn != nil {
		fn(rs, &err)
	}
	if rs.cancel != nil {
		rs.cancel()
	}

	if rs.closeStmt != nil {
		rs.closeStmt.Close()
	}
	rs.releaseConn(err)

	rs.lasterr = rs.lasterrOrErrLocked(err)
	return err
}
3401
3402
// Row is the result of a QueryRow/QueryRowContext call. Scan reports err
// first if it is set; otherwise it reads a single row from rows.
type Row struct {
	err  error // deferred error from the query
	rows *Rows
}
3408
3409
3410
3411
3412
3413
// Scan copies the columns of the first row into dest and then closes the
// underlying Rows. It returns the deferred query error if there is one, the
// iteration error or ErrNoRows when there is no row, and otherwise the error
// from Rows.Close.
//
// *RawBytes destinations are rejected because the Rows are closed before
// Scan returns.
func (r *Row) Scan(dest ...any) error {
	if r.err != nil {
		return r.err
	}

	// Always close the Rows, including on the error returns below. Closing
	// an already-closed Rows is a no-op.
	defer r.rows.Close()
	for _, dp := range dest {
		if _, ok := dp.(*RawBytes); ok {
			return errors.New("sql: RawBytes isn't allowed on Row.Scan")
		}
	}

	if !r.rows.Next() {
		if err := r.rows.Err(); err != nil {
			return err
		}
		return ErrNoRows
	}
	err := r.rows.Scan(dest...)
	if err != nil {
		return err
	}
	// Close explicitly so an error from finishing the query is returned.
	return r.rows.Close()
}
3452
3453
3454
3455
3456
3457 func (r *Row) Err() error {
3458 return r.err
3459 }
3460
3461
// Result summarizes an executed statement. The implementation in this file,
// driverResult, forwards both methods to the driver's driver.Result.
type Result interface {
	// LastInsertId returns the last-insert ID as reported by the driver.
	// Drivers may not support it, in which case they return an error.
	LastInsertId() (int64, error)

	// RowsAffected returns the affected-row count as reported by the
	// driver. Drivers may not support it, in which case they return an
	// error.
	RowsAffected() (int64, error)
}
3475
// driverResult adapts a driver.Result to Result, serializing every call with
// the embedded Locker (resultFromStatement passes the statement's
// connection lock).
type driverResult struct {
	sync.Locker // held around each call into resi
	resi        driver.Result
}

// LastInsertId forwards to the driver result while holding the lock.
func (dr driverResult) LastInsertId() (int64, error) {
	dr.Lock()
	defer dr.Unlock()
	id, err := dr.resi.LastInsertId()
	return id, err
}

// RowsAffected forwards to the driver result while holding the lock.
func (dr driverResult) RowsAffected() (int64, error) {
	dr.Lock()
	defer dr.Unlock()
	n, err := dr.resi.RowsAffected()
	return n, err
}
3492
// stack returns the calling goroutine's stack trace, truncated to at most
// 2 KiB.
func stack() string {
	var scratch [2 << 10]byte
	n := runtime.Stack(scratch[:], false)
	return string(scratch[:n])
}
3497
3498
// withLock runs fn while holding lk. The lock is released even if fn
// panics.
func withLock(lk sync.Locker, fn func()) {
	lk.Lock()
	defer func() { lk.Unlock() }()
	fn()
}
3504
View as plain text