1
2
3
4
5
6
7 package sql
8
9 import (
10 "bytes"
11 "database/sql/driver"
12 "errors"
13 "fmt"
14 "reflect"
15 "strconv"
16 "time"
17 "unicode"
18 "unicode/utf8"
19 )
20
// errNilPtr is returned by the conversion helpers when the destination
// pointer they were asked to fill is nil.
var errNilPtr error = errors.New("destination pointer is nil")
22
// describeNamedValue renders a short human-readable label for an argument,
// used in error messages: `with name "x"` for named arguments and `$N`
// (N being the 1-based ordinal) for positional ones.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
29
// validateNamedValueName reports an error when a non-empty argument name does
// not start with a Unicode letter. An empty name (a positional argument) is
// always accepted. Only the first rune is inspected.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
40
41
42
43
44 type ccChecker struct {
45 cci driver.ColumnConverter
46 want int
47 }
48
49 func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
50 if c.cci == nil {
51 return driver.ErrSkip
52 }
53
54
55
56 index := nv.Ordinal - 1
57 if c.want <= index {
58 return nil
59 }
60
61
62
63
64 if vr, ok := nv.Value.(driver.Valuer); ok {
65 sv, err := callValuerValue(vr)
66 if err != nil {
67 return err
68 }
69 if !driver.IsValue(sv) {
70 return fmt.Errorf("non-subset type %T returned from Value", sv)
71 }
72 nv.Value = sv
73 }
74
75
76
77
78
79
80
81
82 var err error
83 arg := nv.Value
84 nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
85 if err != nil {
86 return err
87 }
88 if !driver.IsValue(nv.Value) {
89 return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
90 }
91 return nil
92 }
93
94
95
96
// defaultCheckNamedValue is the fallback argument checker: it runs nv.Value
// through driver.DefaultParameterConverter and stores whatever the converter
// returned back into nv.Value, even when an error is reported.
func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
	converted, err := driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
101
102
103
104
105
106
107
108 func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
109 nvargs := make([]driver.NamedValue, len(args))
110
111
112
113
114 want := -1
115
116 var si driver.Stmt
117 var cc ccChecker
118 if ds != nil {
119 si = ds.si
120 want = ds.si.NumInput()
121 cc.want = want
122 }
123
124
125
126
127
128 nvc, ok := si.(driver.NamedValueChecker)
129 if !ok {
130 nvc, ok = ci.(driver.NamedValueChecker)
131 }
132 cci, ok := si.(driver.ColumnConverter)
133 if ok {
134 cc.cci = cci
135 }
136
137
138
139
140
141
142 var err error
143 var n int
144 for _, arg := range args {
145 nv := &nvargs[n]
146 if np, ok := arg.(NamedArg); ok {
147 if err = validateNamedValueName(np.Name); err != nil {
148 return nil, err
149 }
150 arg = np.Value
151 nv.Name = np.Name
152 }
153 nv.Ordinal = n + 1
154 nv.Value = arg
155
156
157
158
159
160
161
162
163
164
165
166
167 checker := defaultCheckNamedValue
168 nextCC := false
169 switch {
170 case nvc != nil:
171 nextCC = cci != nil
172 checker = nvc.CheckNamedValue
173 case cci != nil:
174 checker = cc.CheckNamedValue
175 }
176
177 nextCheck:
178 err = checker(nv)
179 switch err {
180 case nil:
181 n++
182 continue
183 case driver.ErrRemoveArgument:
184 nvargs = nvargs[:len(nvargs)-1]
185 continue
186 case driver.ErrSkip:
187 if nextCC {
188 nextCC = false
189 checker = cc.CheckNamedValue
190 } else {
191 checker = defaultCheckNamedValue
192 }
193 goto nextCheck
194 default:
195 return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
196 }
197 }
198
199
200
201 if want != -1 && len(nvargs) != want {
202 return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
203 }
204
205 return nvargs, nil
206
207 }
208
209
210
211 func convertAssign(dest, src any) error {
212 return convertAssignRows(dest, src, nil)
213 }
214
215
216
217
218
219
220 func convertAssignRows(dest, src any, rows *Rows) error {
221
222 switch s := src.(type) {
223 case string:
224 switch d := dest.(type) {
225 case *string:
226 if d == nil {
227 return errNilPtr
228 }
229 *d = s
230 return nil
231 case *[]byte:
232 if d == nil {
233 return errNilPtr
234 }
235 *d = []byte(s)
236 return nil
237 case *RawBytes:
238 if d == nil {
239 return errNilPtr
240 }
241 *d = append((*d)[:0], s...)
242 return nil
243 }
244 case []byte:
245 switch d := dest.(type) {
246 case *string:
247 if d == nil {
248 return errNilPtr
249 }
250 *d = string(s)
251 return nil
252 case *any:
253 if d == nil {
254 return errNilPtr
255 }
256 *d = bytes.Clone(s)
257 return nil
258 case *[]byte:
259 if d == nil {
260 return errNilPtr
261 }
262 *d = bytes.Clone(s)
263 return nil
264 case *RawBytes:
265 if d == nil {
266 return errNilPtr
267 }
268 *d = s
269 return nil
270 }
271 case time.Time:
272 switch d := dest.(type) {
273 case *time.Time:
274 *d = s
275 return nil
276 case *string:
277 *d = s.Format(time.RFC3339Nano)
278 return nil
279 case *[]byte:
280 if d == nil {
281 return errNilPtr
282 }
283 *d = []byte(s.Format(time.RFC3339Nano))
284 return nil
285 case *RawBytes:
286 if d == nil {
287 return errNilPtr
288 }
289 *d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
290 return nil
291 }
292 case decimalDecompose:
293 switch d := dest.(type) {
294 case decimalCompose:
295 return d.Compose(s.Decompose(nil))
296 }
297 case nil:
298 switch d := dest.(type) {
299 case *any:
300 if d == nil {
301 return errNilPtr
302 }
303 *d = nil
304 return nil
305 case *[]byte:
306 if d == nil {
307 return errNilPtr
308 }
309 *d = nil
310 return nil
311 case *RawBytes:
312 if d == nil {
313 return errNilPtr
314 }
315 *d = nil
316 return nil
317 }
318
319 case driver.Rows:
320 switch d := dest.(type) {
321 case *Rows:
322 if d == nil {
323 return errNilPtr
324 }
325 if rows == nil {
326 return errors.New("invalid context to convert cursor rows, missing parent *Rows")
327 }
328 rows.closemu.Lock()
329 *d = Rows{
330 dc: rows.dc,
331 releaseConn: func(error) {},
332 rowsi: s,
333 }
334
335 parentCancel := rows.cancel
336 rows.cancel = func() {
337
338
339 d.close(rows.lasterr)
340 if parentCancel != nil {
341 parentCancel()
342 }
343 }
344 rows.closemu.Unlock()
345 return nil
346 }
347 }
348
349 var sv reflect.Value
350
351 switch d := dest.(type) {
352 case *string:
353 sv = reflect.ValueOf(src)
354 switch sv.Kind() {
355 case reflect.Bool,
356 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
357 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
358 reflect.Float32, reflect.Float64:
359 *d = asString(src)
360 return nil
361 }
362 case *[]byte:
363 sv = reflect.ValueOf(src)
364 if b, ok := asBytes(nil, sv); ok {
365 *d = b
366 return nil
367 }
368 case *RawBytes:
369 sv = reflect.ValueOf(src)
370 if b, ok := asBytes([]byte(*d)[:0], sv); ok {
371 *d = RawBytes(b)
372 return nil
373 }
374 case *bool:
375 bv, err := driver.Bool.ConvertValue(src)
376 if err == nil {
377 *d = bv.(bool)
378 }
379 return err
380 case *any:
381 *d = src
382 return nil
383 }
384
385 if scanner, ok := dest.(Scanner); ok {
386 return scanner.Scan(src)
387 }
388
389 dpv := reflect.ValueOf(dest)
390 if dpv.Kind() != reflect.Pointer {
391 return errors.New("destination not a pointer")
392 }
393 if dpv.IsNil() {
394 return errNilPtr
395 }
396
397 if !sv.IsValid() {
398 sv = reflect.ValueOf(src)
399 }
400
401 dv := reflect.Indirect(dpv)
402 if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
403 switch b := src.(type) {
404 case []byte:
405 dv.Set(reflect.ValueOf(bytes.Clone(b)))
406 default:
407 dv.Set(sv)
408 }
409 return nil
410 }
411
412 if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
413 dv.Set(sv.Convert(dv.Type()))
414 return nil
415 }
416
417
418
419
420
421
422 switch dv.Kind() {
423 case reflect.Pointer:
424 if src == nil {
425 dv.Set(reflect.Zero(dv.Type()))
426 return nil
427 }
428 dv.Set(reflect.New(dv.Type().Elem()))
429 return convertAssignRows(dv.Interface(), src, rows)
430 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
431 if src == nil {
432 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
433 }
434 s := asString(src)
435 i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
436 if err != nil {
437 err = strconvErr(err)
438 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
439 }
440 dv.SetInt(i64)
441 return nil
442 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
443 if src == nil {
444 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
445 }
446 s := asString(src)
447 u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
448 if err != nil {
449 err = strconvErr(err)
450 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
451 }
452 dv.SetUint(u64)
453 return nil
454 case reflect.Float32, reflect.Float64:
455 if src == nil {
456 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
457 }
458 s := asString(src)
459 f64, err := strconv.ParseFloat(s, dv.Type().Bits())
460 if err != nil {
461 err = strconvErr(err)
462 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
463 }
464 dv.SetFloat(f64)
465 return nil
466 case reflect.String:
467 if src == nil {
468 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
469 }
470 switch v := src.(type) {
471 case string:
472 dv.SetString(v)
473 return nil
474 case []byte:
475 dv.SetString(string(v))
476 return nil
477 }
478 }
479
480 return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
481 }
482
// strconvErr unwraps a *strconv.NumError to its underlying cause (such as
// strconv.ErrSyntax or strconv.ErrRange) so that error messages omit the
// redundant function and input details. Any other error is returned as is.
func strconvErr(err error) error {
	switch numErr := err.(type) {
	case *strconv.NumError:
		return numErr.Err
	default:
		return err
	}
}
489
// asString renders src as a string.
//
// Plain string and []byte values are returned as their contents. Values whose
// underlying kind is a bool, integer or float use the strconv formatting for
// that kind (shortest 'g' representation for floats, at the value's own bit
// size). Anything else, including nil, falls back to fmt's %v formatting.
func asString(src any) string {
	if s, ok := src.(string); ok {
		return s
	}
	if b, ok := src.([]byte); ok {
		return string(b)
	}
	rv := reflect.ValueOf(src)
	switch rv.Kind() {
	case reflect.Bool:
		return strconv.FormatBool(rv.Bool())
	case reflect.Float32, reflect.Float64:
		// Bits() is 32 or 64 here, matching the value's precision.
		return strconv.FormatFloat(rv.Float(), 'g', -1, rv.Type().Bits())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(rv.Uint(), 10)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(rv.Int(), 10)
	default:
		return fmt.Sprintf("%v", src)
	}
}
512
// asBytes appends the textual form of rv to buf and reports whether rv's
// kind is supported. Supported kinds are integers, unsigned integers, floats
// (shortest 'g' form at the value's bit size), bools and strings. For any
// other kind, including an invalid Value, it returns nil and false.
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	switch rv.Kind() {
	case reflect.String:
		b = append(buf, rv.String()...)
	case reflect.Bool:
		b = strconv.AppendBool(buf, rv.Bool())
	case reflect.Float32, reflect.Float64:
		b = strconv.AppendFloat(buf, rv.Float(), 'g', -1, rv.Type().Bits())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		b = strconv.AppendUint(buf, rv.Uint(), 10)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		b = strconv.AppendInt(buf, rv.Int(), 10)
	default:
		return nil, false
	}
	return b, true
}
531
// valuerReflectType is the reflect.Type of the driver.Valuer interface.
var valuerReflectType = reflect.TypeOf(new(driver.Valuer)).Elem()

// callValuerValue returns vr.Value(), except that a nil pointer whose element
// type itself implements driver.Valuer (i.e. Value has a value receiver)
// yields (nil, nil) instead of calling the method, which would otherwise
// dereference the nil pointer. Nil pointers of types with pointer-receiver
// Value methods are still passed to the method, which may handle nil itself.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)
	nilPointer := rv.Kind() == reflect.Pointer && rv.IsNil()
	if nilPointer && rv.Type().Elem().Implements(valuerReflectType) {
		return nil, nil
	}
	return vr.Value()
}
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
type (
	// decimal is satisfied by types that can both break themselves down into
	// parts and be rebuilt from those parts. convertAssignRows passes the
	// result of one value's Decompose directly to another value's Compose.
	decimal interface {
		decimalDecompose
		decimalCompose
	}

	// decimalDecompose exposes a value's decimal representation as a form
	// byte, a sign, a coefficient and an exponent.
	// NOTE(review): buf presumably lets callers supply storage for the
	// coefficient; convertAssignRows passes nil. Confirm the intended contract.
	decimalDecompose interface {
		Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
	}

	// decimalCompose sets a value from the same parts Decompose produces,
	// returning an error when it cannot.
	decimalCompose interface {
		Compose(form byte, negative bool, coefficient []byte, exponent int32) error
	}
)
592
View as plain text