1
2
3
4
5 package textproto
6
7 import (
8 "bufio"
9 "bytes"
10 "fmt"
11 "io"
12 "strconv"
13 "strings"
14 "sync"
15 )
16
17
18
19 type Reader struct {
20 R *bufio.Reader
21 dot *dotReader
22 buf []byte
23 }
24
25
26
27
28
29
30 func NewReader(r *bufio.Reader) *Reader {
31 return &Reader{R: r}
32 }
33
34
35
36 func (r *Reader) ReadLine() (string, error) {
37 line, err := r.readLineSlice()
38 return string(line), err
39 }
40
41
42 func (r *Reader) ReadLineBytes() ([]byte, error) {
43 line, err := r.readLineSlice()
44 if line != nil {
45 line = bytes.Clone(line)
46 }
47 return line, err
48 }
49
50 func (r *Reader) readLineSlice() ([]byte, error) {
51 r.closeDot()
52 var line []byte
53 for {
54 l, more, err := r.R.ReadLine()
55 if err != nil {
56 return nil, err
57 }
58
59 if line == nil && !more {
60 return l, nil
61 }
62 line = append(line, l...)
63 if !more {
64 break
65 }
66 }
67 return line, nil
68 }
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88 func (r *Reader) ReadContinuedLine() (string, error) {
89 line, err := r.readContinuedLineSlice(noValidation)
90 return string(line), err
91 }
92
93
94
// trim returns the sub-slice of s with all leading and trailing spaces and
// tabs removed. The result aliases s; nothing is copied.
func trim(s []byte) []byte {
	start, end := 0, len(s)
	for ; start < end; start++ {
		if c := s[start]; c != ' ' && c != '\t' {
			break
		}
	}
	for ; end > start; end-- {
		if c := s[end-1]; c != ' ' && c != '\t' {
			break
		}
	}
	return s[start:end]
}
106
107
108
109 func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
110 line, err := r.readContinuedLineSlice(noValidation)
111 if line != nil {
112 line = bytes.Clone(line)
113 }
114 return line, err
115 }
116
117
118
119
120
121 func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) {
122 if validateFirstLine == nil {
123 return nil, fmt.Errorf("missing validateFirstLine func")
124 }
125
126
127 line, err := r.readLineSlice()
128 if err != nil {
129 return nil, err
130 }
131 if len(line) == 0 {
132 return line, nil
133 }
134
135 if err := validateFirstLine(line); err != nil {
136 return nil, err
137 }
138
139
140
141
142
143 if r.R.Buffered() > 1 {
144 peek, _ := r.R.Peek(2)
145 if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
146 len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
147 return trim(line), nil
148 }
149 }
150
151
152
153 r.buf = append(r.buf[:0], trim(line)...)
154
155
156 for r.skipSpace() > 0 {
157 line, err := r.readLineSlice()
158 if err != nil {
159 break
160 }
161 r.buf = append(r.buf, ' ')
162 r.buf = append(r.buf, trim(line)...)
163 }
164 return r.buf, nil
165 }
166
167
168 func (r *Reader) skipSpace() int {
169 n := 0
170 for {
171 c, err := r.R.ReadByte()
172 if err != nil {
173
174 break
175 }
176 if c != ' ' && c != '\t' {
177 r.R.UnreadByte()
178 break
179 }
180 n++
181 }
182 return n
183 }
184
185 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
186 line, err := r.ReadLine()
187 if err != nil {
188 return
189 }
190 return parseCodeLine(line, expectCode)
191 }
192
193 func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
194 if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
195 err = ProtocolError("short response: " + line)
196 return
197 }
198 continued = line[3] == '-'
199 code, err = strconv.Atoi(line[0:3])
200 if err != nil || code < 100 {
201 err = ProtocolError("invalid response code: " + line)
202 return
203 }
204 message = line[4:]
205 if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
206 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
207 100 <= expectCode && expectCode < 1000 && code != expectCode {
208 err = &Error{code, message}
209 }
210 return
211 }
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230 func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
231 code, continued, message, err := r.readCodeLine(expectCode)
232 if err == nil && continued {
233 err = ProtocolError("unexpected multi-line response: " + message)
234 }
235 return
236 }
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264 func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
265 code, continued, message, err := r.readCodeLine(expectCode)
266 multi := continued
267 for continued {
268 line, err := r.ReadLine()
269 if err != nil {
270 return 0, "", err
271 }
272
273 var code2 int
274 var moreMessage string
275 code2, continued, moreMessage, err = parseCodeLine(line, 0)
276 if err != nil || code2 != code {
277 message += "\n" + strings.TrimRight(line, "\r\n")
278 continued = true
279 continue
280 }
281 message += "\n" + moreMessage
282 }
283 if err != nil && multi && message != "" {
284
285 err = &Error{code, message}
286 }
287 return
288 }
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306 func (r *Reader) DotReader() io.Reader {
307 r.closeDot()
308 r.dot = &dotReader{r: r}
309 return r.dot
310 }
311
312 type dotReader struct {
313 r *Reader
314 state int
315 }
316
317
318 func (d *dotReader) Read(b []byte) (n int, err error) {
319
320
321
322 const (
323 stateBeginLine = iota
324 stateDot
325 stateDotCR
326 stateCR
327 stateData
328 stateEOF
329 )
330 br := d.r.R
331 for n < len(b) && d.state != stateEOF {
332 var c byte
333 c, err = br.ReadByte()
334 if err != nil {
335 if err == io.EOF {
336 err = io.ErrUnexpectedEOF
337 }
338 break
339 }
340 switch d.state {
341 case stateBeginLine:
342 if c == '.' {
343 d.state = stateDot
344 continue
345 }
346 if c == '\r' {
347 d.state = stateCR
348 continue
349 }
350 d.state = stateData
351
352 case stateDot:
353 if c == '\r' {
354 d.state = stateDotCR
355 continue
356 }
357 if c == '\n' {
358 d.state = stateEOF
359 continue
360 }
361 d.state = stateData
362
363 case stateDotCR:
364 if c == '\n' {
365 d.state = stateEOF
366 continue
367 }
368
369
370 br.UnreadByte()
371 c = '\r'
372 d.state = stateData
373
374 case stateCR:
375 if c == '\n' {
376 d.state = stateBeginLine
377 break
378 }
379
380 br.UnreadByte()
381 c = '\r'
382 d.state = stateData
383
384 case stateData:
385 if c == '\r' {
386 d.state = stateCR
387 continue
388 }
389 if c == '\n' {
390 d.state = stateBeginLine
391 }
392 }
393 b[n] = c
394 n++
395 }
396 if err == nil && d.state == stateEOF {
397 err = io.EOF
398 }
399 if err != nil && d.r.dot == d {
400 d.r.dot = nil
401 }
402 return
403 }
404
405
406
407 func (r *Reader) closeDot() {
408 if r.dot == nil {
409 return
410 }
411 buf := make([]byte, 128)
412 for r.dot != nil {
413
414
415 r.dot.Read(buf)
416 }
417 }
418
419
420
421
422 func (r *Reader) ReadDotBytes() ([]byte, error) {
423 return io.ReadAll(r.DotReader())
424 }
425
426
427
428
429
430 func (r *Reader) ReadDotLines() ([]string, error) {
431
432
433
434 var v []string
435 var err error
436 for {
437 var line string
438 line, err = r.ReadLine()
439 if err != nil {
440 if err == io.EOF {
441 err = io.ErrUnexpectedEOF
442 }
443 break
444 }
445
446
447 if len(line) > 0 && line[0] == '.' {
448 if len(line) == 1 {
449 break
450 }
451 line = line[1:]
452 }
453 v = append(v, line)
454 }
455 return v, err
456 }
457
// colon is the separator between a header field name and its value.
var colon = []byte{':'}
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479 func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
480
481
482
483 var strs []string
484 hint := r.upcomingHeaderNewlines()
485 if hint > 0 {
486 strs = make([]string, hint)
487 }
488
489 m := make(MIMEHeader, hint)
490
491
492 if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
493 line, err := r.readLineSlice()
494 if err != nil {
495 return m, err
496 }
497 return m, ProtocolError("malformed MIME header initial line: " + string(line))
498 }
499
500 for {
501 kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon)
502 if len(kv) == 0 {
503 return m, err
504 }
505
506
507 k, v, ok := bytes.Cut(kv, colon)
508 if !ok {
509 return m, ProtocolError("malformed MIME header line: " + string(kv))
510 }
511 key, ok := canonicalMIMEHeaderKey(k)
512 if !ok {
513 return m, ProtocolError("malformed MIME header line: " + string(kv))
514 }
515 for _, c := range v {
516 if !validHeaderValueByte(c) {
517 return m, ProtocolError("malformed MIME header line: " + string(kv))
518 }
519 }
520
521
522
523
524 if key == "" {
525 continue
526 }
527
528
529 value := strings.TrimLeft(string(v), " \t")
530
531 vv := m[key]
532 if vv == nil && len(strs) > 0 {
533
534
535
536
537 vv, strs = strs[:1:1], strs[1:]
538 vv[0] = value
539 m[key] = vv
540 } else {
541 m[key] = append(vv, value)
542 }
543
544 if err != nil {
545 return m, err
546 }
547 }
548 }
549
550
551
// noValidation is a first-line validator for readContinuedLineSlice that
// accepts every line.
func noValidation([]byte) error {
	return nil
}
553
554
555
556
557 func mustHaveFieldNameColon(line []byte) error {
558 if bytes.IndexByte(line, ':') < 0 {
559 return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
560 }
561 return nil
562 }
563
// nl is the newline separator used when scanning buffered header data.
var nl = []byte{'\n'}
565
566
567
568 func (r *Reader) upcomingHeaderNewlines() (n int) {
569
570 r.R.Peek(1)
571 s := r.R.Buffered()
572 if s == 0 {
573 return
574 }
575 peek, _ := r.R.Peek(s)
576 return bytes.Count(peek, nl)
577 }
578
579
580
581
582
583
584
585
586
587 func CanonicalMIMEHeaderKey(s string) string {
588
589 upper := true
590 for i := 0; i < len(s); i++ {
591 c := s[i]
592 if !validHeaderFieldByte(c) {
593 return s
594 }
595 if upper && 'a' <= c && c <= 'z' {
596 s, _ = canonicalMIMEHeaderKey([]byte(s))
597 return s
598 }
599 if !upper && 'A' <= c && c <= 'Z' {
600 s, _ = canonicalMIMEHeaderKey([]byte(s))
601 return s
602 }
603 upper = c == '-'
604 }
605 return s
606 }
607
// toLower is the distance between an ASCII upper-case letter and its
// lower-case counterpart; adding it lowers a letter, subtracting it raises
// one.
const toLower = 'a' - 'A'
609
610
611
612
613
614
615
616
617
// validHeaderFieldByte reports whether c may appear in a header field name.
// The accepted bytes are ASCII letters, digits and the symbols
//
//	! # $ % & ' * + - . ^ _ ` | ~
//
// which is the "tchar" set used for tokens in RFC 7230.
func validHeaderFieldByte(c byte) bool {
	// The accepted set is a 128-bit bitmap split into two 64-bit halves so
	// membership is a shift and a mask. For c >= 128 both shifts produce
	// zero (c-64 wraps to >= 192 when c < 64), so the answer is false.
	const (
		digits  = (1<<10 - 1) << '0'
		lowers  = (1<<26 - 1) << 'a'
		uppers  = (1<<26 - 1) << 'A'
		symbols = 1<<'!' | 1<<'#' | 1<<'$' | 1<<'%' | 1<<'&' | 1<<'\'' |
			1<<'*' | 1<<'+' | 1<<'-' | 1<<'.' | 1<<'^' | 1<<'_' |
			1<<'`' | 1<<'|' | 1<<'~'
		mask = digits | lowers | uppers | symbols

		lo uint64 = mask & (1<<64 - 1) // bytes 0..63
		hi uint64 = mask >> 64         // bytes 64..127
	)
	return ((uint64(1)<<c)&lo | (uint64(1)<<(c-64))&hi) != 0
}
645
646
647
648
649
650
651
652
653
654
655
656
657
// validHeaderValueByte reports whether c may appear in a header field value.
// The accepted bytes are horizontal tab (0x09), space (0x20), the visible
// ASCII range 0x21-0x7E, and every byte >= 0x80.
func validHeaderValueByte(c byte) bool {
	// mask has a 1 for every accepted byte below 128. The test is inverted
	// (AND-NOT against the mask, compared with zero) so that bytes >= 128,
	// for which both shifts produce zero, count as valid.
	const (
		vchar = (1<<(0x7f-0x21) - 1) << 0x21 // 0x21-0x7E
		sp    = 1 << 0x20
		htab  = 1 << 0x09
		mask  = vchar | sp | htab

		lo uint64 = mask & (1<<64 - 1) // bytes 0..63
		hi uint64 = mask >> 64         // bytes 64..127
	)
	return ((uint64(1)<<c)&^lo | (uint64(1)<<(c-64))&^hi) == 0
}
671
672
673
674
675
676
677
678
679
680
681
682 func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
683
684 noCanon := false
685 for _, c := range a {
686 if validHeaderFieldByte(c) {
687 continue
688 }
689
690 if c == ' ' {
691
692
693
694 noCanon = true
695 continue
696 }
697 return string(a), false
698 }
699 if noCanon {
700 return string(a), true
701 }
702
703 upper := true
704 for i, c := range a {
705
706
707
708
709 if upper && 'a' <= c && c <= 'z' {
710 c -= toLower
711 } else if !upper && 'A' <= c && c <= 'Z' {
712 c += toLower
713 }
714 a[i] = c
715 upper = c == '-'
716 }
717 commonHeaderOnce.Do(initCommonHeader)
718
719
720
721 if v := commonHeader[string(a)]; v != "" {
722 return v, true
723 }
724 return string(a), true
725 }
726
727
var (
	// commonHeader maps the canonical spelling of frequently seen header
	// keys to a shared string with that same spelling. It is filled in by
	// initCommonHeader.
	commonHeader map[string]string

	// commonHeaderOnce guards the one-time call to initCommonHeader.
	commonHeaderOnce sync.Once
)
731
732 func initCommonHeader() {
733 commonHeader = make(map[string]string)
734 for _, v := range []string{
735 "Accept",
736 "Accept-Charset",
737 "Accept-Encoding",
738 "Accept-Language",
739 "Accept-Ranges",
740 "Cache-Control",
741 "Cc",
742 "Connection",
743 "Content-Id",
744 "Content-Language",
745 "Content-Length",
746 "Content-Transfer-Encoding",
747 "Content-Type",
748 "Cookie",
749 "Date",
750 "Dkim-Signature",
751 "Etag",
752 "Expires",
753 "From",
754 "Host",
755 "If-Modified-Since",
756 "If-None-Match",
757 "In-Reply-To",
758 "Last-Modified",
759 "Location",
760 "Message-Id",
761 "Mime-Version",
762 "Pragma",
763 "Received",
764 "Return-Path",
765 "Server",
766 "Set-Cookie",
767 "Subject",
768 "To",
769 "User-Agent",
770 "Via",
771 "X-Forwarded-For",
772 "X-Imforwards",
773 "X-Powered-By",
774 } {
775 commonHeader[v] = v
776 }
777 }
778
View as plain text