1
2
3
4
5 package textproto
6
7 import (
8 "bufio"
9 "bytes"
10 "errors"
11 "fmt"
12 "io"
13 "math"
14 "strconv"
15 "strings"
16 "sync"
17 )
18
19
20
// A Reader implements convenience methods for reading requests
// or responses from a text protocol network connection.
type Reader struct {
	R   *bufio.Reader // underlying buffered reader; every method reads from it
	dot *dotReader    // DotReader currently in progress, if any; drained by closeDot
	buf []byte        // reusable scratch buffer for readContinuedLineSlice
}
26
27
28
29
30
31
32 func NewReader(r *bufio.Reader) *Reader {
33 return &Reader{R: r}
34 }
35
36
37
38 func (r *Reader) ReadLine() (string, error) {
39 line, err := r.readLineSlice()
40 return string(line), err
41 }
42
43
44 func (r *Reader) ReadLineBytes() ([]byte, error) {
45 line, err := r.readLineSlice()
46 if line != nil {
47 line = bytes.Clone(line)
48 }
49 return line, err
50 }
51
52 func (r *Reader) readLineSlice() ([]byte, error) {
53 r.closeDot()
54 var line []byte
55 for {
56 l, more, err := r.R.ReadLine()
57 if err != nil {
58 return nil, err
59 }
60
61 if line == nil && !more {
62 return l, nil
63 }
64 line = append(line, l...)
65 if !more {
66 break
67 }
68 }
69 return line, nil
70 }
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90 func (r *Reader) ReadContinuedLine() (string, error) {
91 line, err := r.readContinuedLineSlice(noValidation)
92 return string(line), err
93 }
94
95
96
// trim returns the sub-slice of s that remains after removing leading and
// trailing spaces and tabs. It works on raw bytes and makes no assumption
// about Unicode or UTF-8. The result aliases s; for a non-nil s that is
// entirely blank the result is empty but non-nil.
func trim(s []byte) []byte {
	start, end := 0, len(s)
	for ; start < end; start++ {
		if c := s[start]; c != ' ' && c != '\t' {
			break
		}
	}
	for ; end > start; end-- {
		if c := s[end-1]; c != ' ' && c != '\t' {
			break
		}
	}
	return s[start:end]
}
108
109
110
111 func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
112 line, err := r.readContinuedLineSlice(noValidation)
113 if line != nil {
114 line = bytes.Clone(line)
115 }
116 return line, err
117 }
118
119
120
121
122
123 func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) {
124 if validateFirstLine == nil {
125 return nil, fmt.Errorf("missing validateFirstLine func")
126 }
127
128
129 line, err := r.readLineSlice()
130 if err != nil {
131 return nil, err
132 }
133 if len(line) == 0 {
134 return line, nil
135 }
136
137 if err := validateFirstLine(line); err != nil {
138 return nil, err
139 }
140
141
142
143
144
145 if r.R.Buffered() > 1 {
146 peek, _ := r.R.Peek(2)
147 if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
148 len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
149 return trim(line), nil
150 }
151 }
152
153
154
155 r.buf = append(r.buf[:0], trim(line)...)
156
157
158 for r.skipSpace() > 0 {
159 line, err := r.readLineSlice()
160 if err != nil {
161 break
162 }
163 r.buf = append(r.buf, ' ')
164 r.buf = append(r.buf, trim(line)...)
165 }
166 return r.buf, nil
167 }
168
169
170 func (r *Reader) skipSpace() int {
171 n := 0
172 for {
173 c, err := r.R.ReadByte()
174 if err != nil {
175
176 break
177 }
178 if c != ' ' && c != '\t' {
179 r.R.UnreadByte()
180 break
181 }
182 n++
183 }
184 return n
185 }
186
187 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
188 line, err := r.ReadLine()
189 if err != nil {
190 return
191 }
192 return parseCodeLine(line, expectCode)
193 }
194
195 func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
196 if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
197 err = ProtocolError("short response: " + line)
198 return
199 }
200 continued = line[3] == '-'
201 code, err = strconv.Atoi(line[0:3])
202 if err != nil || code < 100 {
203 err = ProtocolError("invalid response code: " + line)
204 return
205 }
206 message = line[4:]
207 if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
208 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
209 100 <= expectCode && expectCode < 1000 && code != expectCode {
210 err = &Error{code, message}
211 }
212 return
213 }
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232 func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
233 code, continued, message, err := r.readCodeLine(expectCode)
234 if err == nil && continued {
235 err = ProtocolError("unexpected multi-line response: " + message)
236 }
237 return
238 }
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266 func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
267 code, continued, message, err := r.readCodeLine(expectCode)
268 multi := continued
269 for continued {
270 line, err := r.ReadLine()
271 if err != nil {
272 return 0, "", err
273 }
274
275 var code2 int
276 var moreMessage string
277 code2, continued, moreMessage, err = parseCodeLine(line, 0)
278 if err != nil || code2 != code {
279 message += "\n" + strings.TrimRight(line, "\r\n")
280 continued = true
281 continue
282 }
283 message += "\n" + moreMessage
284 }
285 if err != nil && multi && message != "" {
286
287 err = &Error{code, message}
288 }
289 return
290 }
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308 func (r *Reader) DotReader() io.Reader {
309 r.closeDot()
310 r.dot = &dotReader{r: r}
311 return r.dot
312 }
313
// dotReader is the io.Reader returned by Reader.DotReader. It decodes a
// dot-encoded block read from r.
type dotReader struct {
	r     *Reader // Reader the encoded bytes are pulled from
	state int     // current state of the decoding state machine in Read
}
318
319
// Read satisfies reads by decoding dot-encoded data read from d.r.
//
// It fills b with decoded bytes and returns io.EOF once the terminating
// dot line has been consumed. If the underlying reader hits io.EOF before
// that line, the error is reported as io.ErrUnexpectedEOF. On any error
// (including io.EOF) d detaches itself from its Reader.
func (d *dotReader) Read(b []byte) (n int, err error) {
	// Run data through a simple state machine to
	// elide leading dots, rewrite trailing \r\n into \n,
	// and detect ending .\r\n line.
	const (
		stateBeginLine = iota // beginning of line; initial state; must be zero
		stateDot              // read . at beginning of line
		stateDotCR            // read .\r at beginning of line
		stateCR               // read \r (possibly at end of line)
		stateData             // reading data in middle of line
		stateEOF              // reached .\r\n end marker line
	)
	br := d.r.R
	for n < len(b) && d.state != stateEOF {
		var c byte
		c, err = br.ReadByte()
		if err != nil {
			// Running out of input before the end marker is a truncation.
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			break
		}
		switch d.state {
		case stateBeginLine:
			if c == '.' {
				// Hold the dot back: it is either stuffing or the end marker.
				d.state = stateDot
				continue
			}
			if c == '\r' {
				d.state = stateCR
				continue
			}
			d.state = stateData

		case stateDot:
			if c == '\r' {
				d.state = stateDotCR
				continue
			}
			if c == '\n' {
				// ".\n" is accepted as the end marker as well.
				d.state = stateEOF
				continue
			}
			// The leading dot was stuffing; drop it and emit c.
			d.state = stateData

		case stateDotCR:
			if c == '\n' {
				d.state = stateEOF
				continue
			}
			// Not part of .\r\n.
			// Consume leading dot and emit saved \r.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateCR:
			if c == '\n' {
				// \r\n collapses to the \n emitted below.
				d.state = stateBeginLine
				break
			}
			// Not part of \r\n. Emit saved \r
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateData:
			if c == '\r' {
				d.state = stateCR
				continue
			}
			if c == '\n' {
				d.state = stateBeginLine
			}
		}
		b[n] = c
		n++
	}
	if err == nil && d.state == stateEOF {
		err = io.EOF
	}
	// Detach from the Reader so closeDot stops draining this dotReader.
	if err != nil && d.r.dot == d {
		d.r.dot = nil
	}
	return
}
406
407
408
409 func (r *Reader) closeDot() {
410 if r.dot == nil {
411 return
412 }
413 buf := make([]byte, 128)
414 for r.dot != nil {
415
416
417 r.dot.Read(buf)
418 }
419 }
420
421
422
423
424 func (r *Reader) ReadDotBytes() ([]byte, error) {
425 return io.ReadAll(r.DotReader())
426 }
427
428
429
430
431
432 func (r *Reader) ReadDotLines() ([]string, error) {
433
434
435
436 var v []string
437 var err error
438 for {
439 var line string
440 line, err = r.ReadLine()
441 if err != nil {
442 if err == io.EOF {
443 err = io.ErrUnexpectedEOF
444 }
445 break
446 }
447
448
449 if len(line) > 0 && line[0] == '.' {
450 if len(line) == 1 {
451 break
452 }
453 line = line[1:]
454 }
455 v = append(v, line)
456 }
457 return v, err
458 }
459
// colon is the separator between a header field name and its value.
var colon = []byte(":")
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481 func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
482 return readMIMEHeader(r, math.MaxInt64, math.MaxInt64)
483 }
484
485
486
// readMIMEHeader is a version of ReadMIMEHeader that enforces limits:
// maxMemory bounds the estimated memory used by the returned header and
// maxHeaders bounds the number of header lines accepted.
//
// On a malformed line it returns the headers parsed so far together with
// a ProtocolError. Exceeding maxMemory likewise returns the partial map
// with a "message too large" error, whereas exceeding maxHeaders returns
// a nil map with that error.
func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) {
	// Avoid lots of small slice allocations later by allocating one
	// large one ahead of time which we'll cut up into smaller
	// slices. If this isn't big enough later, we allocate small ones.
	var strs []string
	hint := r.upcomingHeaderKeys()
	if hint > 0 {
		if hint > 1000 {
			hint = 1000 // set a cap to avoid overallocation
		}
		strs = make([]string, hint)
	}

	m := make(MIMEHeader, hint)

	// Charge a fixed 400 bytes for the MIMEHeader itself, plus
	// mapEntryOverhead bytes for every distinct key added below.
	// These figures are estimates, not measurements made here.
	maxMemory -= 400
	const mapEntryOverhead = 200

	// The first line cannot start with a leading space.
	if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
		line, err := r.readLineSlice()
		if err != nil {
			return m, err
		}
		return m, ProtocolError("malformed MIME header initial line: " + string(line))
	}

	for {
		kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon)
		// An empty line ends the header; err (possibly nil) is passed on.
		if len(kv) == 0 {
			return m, err
		}

		// Key ends at first colon.
		k, v, ok := bytes.Cut(kv, colon)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		key, ok := canonicalMIMEHeaderKey(k)
		if !ok {
			return m, ProtocolError("malformed MIME header line: " + string(kv))
		}
		for _, c := range v {
			if !validHeaderValueByte(c) {
				return m, ProtocolError("malformed MIME header line: " + string(kv))
			}
		}

		// A field name must not be empty, but rather than reject the
		// whole header we are liberal in what we accept and simply
		// skip a line whose key is empty.
		if key == "" {
			continue
		}

		maxHeaders--
		if maxHeaders < 0 {
			return nil, errors.New("message too large")
		}

		// Skip initial spaces in value.
		value := string(bytes.TrimLeft(v, " \t"))

		vv := m[key]
		if vv == nil {
			// First value for this key: charge for the key and map entry.
			maxMemory -= int64(len(key))
			maxMemory -= mapEntryOverhead
		}
		maxMemory -= int64(len(value))
		if maxMemory < 0 {
			// TODO: This should be a distinguishable error (ErrMessageTooLarge)
			// so that callers can detect it.
			return m, errors.New("message too large")
		}
		if vv == nil && len(strs) > 0 {
			// More than likely this will be a single-element key.
			// Most headers aren't multi-valued.
			// Set the capacity on strs[0] to 1, so any future append
			// won't extend the slice into the other strings.
			vv, strs = strs[:1:1], strs[1:]
			vv[0] = value
			m[key] = vv
		} else {
			m[key] = append(vv, value)
		}

		if err != nil {
			return m, err
		}
	}
}
581
582
583
// noValidation is a first-line validator for readContinuedLineSlice that
// accepts every line.
func noValidation(_ []byte) error {
	return nil
}
585
586
587
588
589 func mustHaveFieldNameColon(line []byte) error {
590 if bytes.IndexByte(line, ':') < 0 {
591 return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
592 }
593 return nil
594 }
595
// nl is the newline separator used to split buffered input into lines.
var nl = []byte("\n")
597
598
599
600 func (r *Reader) upcomingHeaderKeys() (n int) {
601
602 r.R.Peek(1)
603 s := r.R.Buffered()
604 if s == 0 {
605 return
606 }
607 peek, _ := r.R.Peek(s)
608 for len(peek) > 0 && n < 1000 {
609 var line []byte
610 line, peek, _ = bytes.Cut(peek, nl)
611 if len(line) == 0 || (len(line) == 1 && line[0] == '\r') {
612
613 break
614 }
615 if line[0] == ' ' || line[0] == '\t' {
616
617 continue
618 }
619 n++
620 }
621 return n
622 }
623
624
625
626
627
628
629
630
631
632 func CanonicalMIMEHeaderKey(s string) string {
633
634 upper := true
635 for i := 0; i < len(s); i++ {
636 c := s[i]
637 if !validHeaderFieldByte(c) {
638 return s
639 }
640 if upper && 'a' <= c && c <= 'z' {
641 s, _ = canonicalMIMEHeaderKey([]byte(s))
642 return s
643 }
644 if !upper && 'A' <= c && c <= 'Z' {
645 s, _ = canonicalMIMEHeaderKey([]byte(s))
646 return s
647 }
648 upper = c == '-'
649 }
650 return s
651 }
652
// toLower is the offset between an ASCII upper-case letter and its
// lower-case counterpart.
const toLower = 'a' - 'A'
654
655
656
657
658
659
660
661
662
// validHeaderFieldByte reports whether c is a valid byte in a header
// field name: an ASCII letter, a digit, or one of
//
//	! # $ % & ' * + - . ^ _ ` | ~
func validHeaderFieldByte(c byte) bool {
	// bitmap is an untyped 128-bit constant with bit i set exactly when
	// byte i is allowed. It is split into two 64-bit words so membership
	// can be tested with shifts and ANDs only.
	const bitmap = 0 |
		1<<'!' | 1<<'#' | 1<<'$' | 1<<'%' | 1<<'&' | 1<<'\'' |
		1<<'*' | 1<<'+' | 1<<'-' | 1<<'.' |
		1<<'^' | 1<<'_' | 1<<'`' | 1<<'|' | 1<<'~' |
		(1<<10-1)<<'0' |
		(1<<26-1)<<'A' |
		(1<<26-1)<<'a'
	const (
		lowWord  uint64 = bitmap & (1<<64 - 1) // bytes 0..63
		highWord uint64 = bitmap >> 64         // bytes 64..127
	)

	// A shift count of 64 or more yields zero. For c < 64 the byte
	// subtraction c-64 wraps to a large count, and for c >= 128 both
	// counts are too large, so out-of-range probes vanish by themselves.
	lowProbe := uint64(1) << c
	highProbe := uint64(1) << (c - 64)
	return (lowProbe&lowWord)|(highProbe&highWord) != 0
}
690
691
692
693
694
695
696
697
698
699
700
701
702
// validHeaderValueByte reports whether c is a valid byte in a header
// field value. Accepted are horizontal tab (0x09), space (0x20), the
// visible ASCII characters 0x21 through 0x7e, and every byte >= 0x80.
// The remaining control characters and DEL (0x7f) are rejected.
func validHeaderValueByte(c byte) bool {
	// bitmap is an untyped 128-bit constant with bit i set exactly when
	// the ASCII byte i is allowed. The test below checks that c's bit is
	// NOT outside the bitmap, which is why bytes >= 128 (whose probes
	// shift out to zero) are accepted.
	const bitmap = 0 |
		1<<0x09 |
		1<<0x20 |
		(1<<(0x7f-0x21)-1)<<0x21
	const (
		lowWord  uint64 = bitmap & (1<<64 - 1) // bytes 0..63
		highWord uint64 = bitmap >> 64         // bytes 64..127
	)

	lowProbe := uint64(1) << c
	highProbe := uint64(1) << (c - 64)
	return (lowProbe&^lowWord)|(highProbe&^highWord) == 0
}
716
717
718
719
720
721
722
723
724
725
726
727 func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
728
729 noCanon := false
730 for _, c := range a {
731 if validHeaderFieldByte(c) {
732 continue
733 }
734
735 if c == ' ' {
736
737
738
739 noCanon = true
740 continue
741 }
742 return string(a), false
743 }
744 if noCanon {
745 return string(a), true
746 }
747
748 upper := true
749 for i, c := range a {
750
751
752
753
754 if upper && 'a' <= c && c <= 'z' {
755 c -= toLower
756 } else if !upper && 'A' <= c && c <= 'Z' {
757 c += toLower
758 }
759 a[i] = c
760 upper = c == '-'
761 }
762 commonHeaderOnce.Do(initCommonHeader)
763
764
765
766 if v := commonHeader[string(a)]; v != "" {
767 return v, true
768 }
769 return string(a), true
770 }
771
772
// commonHeader interns common header strings: each well-known canonical
// header name maps to itself. It is populated by initCommonHeader.
var commonHeader map[string]string

// commonHeaderOnce guards the one-time initialization of commonHeader.
var commonHeaderOnce sync.Once

// initCommonHeader builds the commonHeader table. It is meant to be run
// through commonHeaderOnce.
func initCommonHeader() {
	names := [...]string{
		"Accept",
		"Accept-Charset",
		"Accept-Encoding",
		"Accept-Language",
		"Accept-Ranges",
		"Cache-Control",
		"Cc",
		"Connection",
		"Content-Id",
		"Content-Language",
		"Content-Length",
		"Content-Transfer-Encoding",
		"Content-Type",
		"Cookie",
		"Date",
		"Dkim-Signature",
		"Etag",
		"Expires",
		"From",
		"Host",
		"If-Modified-Since",
		"If-None-Match",
		"In-Reply-To",
		"Last-Modified",
		"Location",
		"Message-Id",
		"Mime-Version",
		"Pragma",
		"Received",
		"Return-Path",
		"Server",
		"Set-Cookie",
		"Subject",
		"To",
		"User-Agent",
		"Via",
		"X-Forwarded-For",
		"X-Imforwards",
		"X-Powered-By",
	}

	table := make(map[string]string, len(names))
	for i := range names {
		table[names[i]] = names[i]
	}
	commonHeader = table
}
823
View as plain text