// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package textproto

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"math"
	"strconv"
	"strings"
	"sync"
)

// TODO: This should be a distinguishable error (ErrMessageTooLarge)
// to allow mime/multipart to detect it.
var errMessageTooLarge = errors.New("message too large")

// A Reader implements convenience methods for reading requests
// or responses from a text protocol network connection.
type Reader struct {
	// R is the buffered reader that all methods read from.
	R *bufio.Reader
	// dot is the DotReader currently attached to this Reader, if any.
	// It is drained by closeDot before any other read takes place.
	dot *dotReader
	buf []byte // a re-usable buffer for readContinuedLineSlice
}

// NewReader returns a new [Reader] reading from r.
//
// To avoid denial of service attacks, the provided [bufio.Reader]
// should be reading from an [io.LimitReader] or similar Reader to bound
// the size of responses.
func NewReader(r *bufio.Reader) *Reader {
	return &Reader{R: r}
}

// ReadLine reads a single line from r,
// eliding the final \n or \r\n from the returned string.
func (r *Reader) ReadLine() (string, error) {
	line, err := r.readLineSlice(-1)
	return string(line), err
}

// ReadLineBytes is like [Reader.ReadLine] but returns a []byte instead of a string.
func (r *Reader) ReadLineBytes() ([]byte, error) {
	line, err := r.readLineSlice(-1)
	if line != nil {
		// readLineSlice may hand back the slice it got from r.R.ReadLine
		// without copying, so give the caller its own copy.
		line = bytes.Clone(line)
	}
	return line, err
}

// readLineSlice reads a single line from r,
// up to lim bytes long (or unlimited if lim is less than 0),
// eliding the final \n or \r\n from the returned slice.
//
// It first drains any active DotReader. If the line exceeds lim bytes,
// it returns errMessageTooLarge. The returned slice may be the one
// returned by r.R.ReadLine itself (no copy is made when the first
// ReadLine call yields the complete line).
func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
	r.closeDot()
	var line []byte
	for {
		l, more, err := r.R.ReadLine()
		if err != nil {
			return nil, err
		}
		if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
			return nil, errMessageTooLarge
		}
		// Avoid the copy if the first call produced a full line.
		if line == nil && !more {
			return l, nil
		}
		// The line did not fit in one ReadLine result; accumulate the pieces.
		line = append(line, l...)
		if !more {
			break
		}
	}
	return line, nil
}

// ReadContinuedLine reads a possibly continued line from r,
// eliding the final trailing ASCII white space.
// Lines after the first are considered continuations if they
// begin with a space or tab character. In the returned data,
// continuation lines are separated from the previous line
// only by a single space: the newline and leading white space
// are removed.
//
// For example, consider this input:
//
//	Line 1
//	  continued...
//	Line 2
//
// The first call to ReadContinuedLine will return "Line 1 continued..."
// and the second will return "Line 2".
//
// Empty lines are never continued.
func (r *Reader) ReadContinuedLine() (string, error) {
	line, err := r.readContinuedLineSlice(-1, noValidation)
	return string(line), err
}

// trim returns s with leading and trailing spaces and tabs removed.
// It does not assume Unicode or UTF-8.
// The result is a sub-slice of s; no copy is made.
func trim(s []byte) []byte {
	i := 0
	for i < len(s) && (s[i] == ' ' || s[i] == '\t') {
		i++
	}
	n := len(s)
	for n > i && (s[n-1] == ' ' || s[n-1] == '\t') {
		n--
	}
	return s[i:n]
}

// ReadContinuedLineBytes is like [Reader.ReadContinuedLine] but
// returns a []byte instead of a string.
func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
	line, err := r.readContinuedLineSlice(-1, noValidation)
	if line != nil {
		// The slice may alias r.buf or the slice returned by readLineSlice;
		// copy it so the caller owns the result.
		line = bytes.Clone(line)
	}
	return line, err
}

// readContinuedLineSlice reads continued lines from the reader buffer,
// returning a byte slice with all lines. The validateFirstLine function
// is run on the first read line, and if it returns an error then this
// error is returned from readContinuedLineSlice.
// It reads up to lim bytes of data (or unlimited if lim is less than 0).
//
// The returned slice is either a trimmed view of the slice produced by
// readLineSlice or r.buf, which is reused by later calls.
func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
	if validateFirstLine == nil {
		return nil, fmt.Errorf("missing validateFirstLine func")
	}

	// Read the first line.
	line, err := r.readLineSlice(lim)
	if err != nil {
		return nil, err
	}
	if len(line) == 0 { // blank line - no continuation
		return line, nil
	}

	if err := validateFirstLine(line); err != nil {
		return nil, err
	}

	// Optimistically assume that we have started to buffer the next line
	// and it starts with an ASCII letter (the next header key), or a blank
	// line, so we can avoid copying that buffered data around in memory
	// and skipping over non-existent whitespace.
	if r.R.Buffered() > 1 {
		peek, _ := r.R.Peek(2)
		if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
			len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
			return trim(line), nil
		}
	}

	// ReadByte or the next readLineSlice will flush the read buffer;
	// copy the slice into buf.
	r.buf = append(r.buf[:0], trim(line)...)

	if lim < 0 {
		lim = math.MaxInt64
	}
	lim -= int64(len(r.buf))

	// Read continuation lines.
	for r.skipSpace() > 0 {
		// Each continuation is joined to the previous text by one space.
		r.buf = append(r.buf, ' ')
		if int64(len(r.buf)) >= lim {
			return nil, errMessageTooLarge
		}
		line, err := r.readLineSlice(lim - int64(len(r.buf)))
		if err != nil {
			// The error is not returned here; whatever has been
			// accumulated so far is returned with a nil error.
			break
		}
		r.buf = append(r.buf, trim(line)...)
	}
	return r.buf, nil
}

// skipSpace skips R over all spaces and returns the number of bytes skipped.
// Only ' ' and '\t' count as spaces; the first other byte is unread so
// that it remains available to the next read.
func (r *Reader) skipSpace() int {
	n := 0
	for {
		c, err := r.R.ReadByte()
		if err != nil {
			// Bufio will keep err until next read.
			break
		}
		if c != ' ' && c != '\t' {
			r.R.UnreadByte()
			break
		}
		n++
	}
	return n
}

// readCodeLine reads one line with ReadLine and parses it with
// parseCodeLine. A read error is returned with the other results
// left at their zero values.
func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
	line, err := r.ReadLine()
	if err != nil {
		return
	}
	return parseCodeLine(line, expectCode)
}

// parseCodeLine parses a response line consisting of a three-character
// code, a separator (' ' or '-'), and a message.
//
// It reports the numeric code, whether the separator was '-'
// (continued), and the text following the separator. A line shorter
// than four bytes or with a different separator yields a ProtocolError,
// as does a code that is not a number or is below 100.
//
// When expectCode has one, two, or three digits, it is compared with
// the leading one or two digits of code, or with code itself; on
// mismatch err is set to &Error{code, message}. Any other expectCode
// disables the comparison.
func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
	if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
		err = ProtocolError("short response: " + line)
		return
	}
	continued = line[3] == '-'
	code, err = strconv.Atoi(line[0:3])
	if err != nil || code < 100 {
		err = ProtocolError("invalid response code: " + line)
		return
	}
	message = line[4:]
	if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
		10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
		100 <= expectCode && expectCode < 1000 && code != expectCode {
		err = &Error{code, message}
	}
	return
}

// ReadCodeLine reads a response code line of the form
//
//	code message
//
// where code is a three-digit status code and the message
// extends to the rest of the line. An example of such a line is:
//
//	220 plan9.bell-labs.com ESMTP
//
// If the prefix of the status does not match the digits in expectCode,
// ReadCodeLine returns with err set to &Error{code, message}.
// For example, if expectCode is 31, an error will be returned if
// the status is not in the range [310,319].
//
// If the response is multi-line, ReadCodeLine returns an error.
//
// An expectCode <= 0 disables the check of the status code.
func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
	code, continued, message, err := r.readCodeLine(expectCode)
	if err == nil && continued {
		err = ProtocolError("unexpected multi-line response: " + message)
	}
	return
}

// ReadResponse reads a multi-line response of the form:
//
//	code-message line 1
//	code-message line 2
//	...
//	code message line n
//
// where code is a three-digit status code.
The first line starts with the // code and a hyphen. The response is terminated by a line that starts // with the same code followed by a space. Each line in message is // separated by a newline (\n). // // See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for // details of another form of response accepted: // // code-message line 1 // message line 2 // ... // code message line n // // If the prefix of the status does not match the digits in expectCode, // ReadResponse returns with err set to &Error{code, message}. // For example, if expectCode is 31, an error will be returned if // the status is not in the range [310,319]. // // An expectCode <= 0 disables the check of the status code. func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) { code, continued, message, err := r.readCodeLine(expectCode) multi := continued for continued { line, err := r.ReadLine() if err != nil { return 0, "", err } var code2 int var moreMessage string code2, continued, moreMessage, err = parseCodeLine(line, 0) if err != nil || code2 != code { message += "\n" + strings.TrimRight(line, "\r\n") continued = true continue } message += "\n" + moreMessage } if err != nil && multi && message != "" { // replace one line error message with all lines (full message) err = &Error{code, message} } return } // DotReader returns a new [Reader] that satisfies Reads using the // decoded text of a dot-encoded block read from r. // The returned Reader is only valid until the next call // to a method on r. // // Dot encoding is a common framing used for data blocks // in text protocols such as SMTP. The data consists of a sequence // of lines, each of which ends in "\r\n". The sequence itself // ends at a line containing just a dot: ".\r\n". Lines beginning // with a dot are escaped with an additional dot to avoid // looking like the end of the sequence. 
// // The decoded form returned by the Reader's Read method // rewrites the "\r\n" line endings into the simpler "\n", // removes leading dot escapes if present, and stops with error [io.EOF] // after consuming (and discarding) the end-of-sequence line. func (r *Reader) DotReader() io.Reader { r.closeDot() r.dot = &dotReader{r: r} return r.dot } type dotReader struct { r *Reader state int } // Read satisfies reads by decoding dot-encoded data read from d.r. func (d *dotReader) Read(b []byte) (n int, err error) { // Run data through a simple state machine to // elide leading dots, rewrite trailing \r\n into \n, // and detect ending .\r\n line. const ( stateBeginLine = iota // beginning of line; initial state; must be zero stateDot // read . at beginning of line stateDotCR // read .\r at beginning of line stateCR // read \r (possibly at end of line) stateData // reading data in middle of line stateEOF // reached .\r\n end marker line ) br := d.r.R for n < len(b) && d.state != stateEOF { var c byte c, err = br.ReadByte() if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } break } switch d.state { case stateBeginLine: if c == '.' { d.state = stateDot continue } if c == '\r' { d.state = stateCR continue } d.state = stateData case stateDot: if c == '\r' { d.state = stateDotCR continue } if c == '\n' { d.state = stateEOF continue } d.state = stateData case stateDotCR: if c == '\n' { d.state = stateEOF continue } // Not part of .\r\n. // Consume leading dot and emit saved \r. br.UnreadByte() c = '\r' d.state = stateData case stateCR: if c == '\n' { d.state = stateBeginLine break } // Not part of \r\n. 
Emit saved \r br.UnreadByte() c = '\r' d.state = stateData case stateData: if c == '\r' { d.state = stateCR continue } if c == '\n' { d.state = stateBeginLine } } b[n] = c n++ } if err == nil && d.state == stateEOF { err = io.EOF } if err != nil && d.r.dot == d { d.r.dot = nil } return } // closeDot drains the current DotReader if any, // making sure that it reads until the ending dot line. func (r *Reader) closeDot() { if r.dot == nil { return } buf := make([]byte, 128) for r.dot != nil { // When Read reaches EOF or an error, // it will set r.dot == nil. r.dot.Read(buf) } } // ReadDotBytes reads a dot-encoding and returns the decoded data. // // See the documentation for the [Reader.DotReader] method for details about dot-encoding. func (r *Reader) ReadDotBytes() ([]byte, error) { return io.ReadAll(r.DotReader()) } // ReadDotLines reads a dot-encoding and returns a slice // containing the decoded lines, with the final \r\n or \n elided from each. // // See the documentation for the [Reader.DotReader] method for details about dot-encoding. func (r *Reader) ReadDotLines() ([]string, error) { // We could use ReadDotBytes and then Split it, // but reading a line at a time avoids needing a // large contiguous block of memory and is simpler. var v []string var err error for { var line string line, err = r.ReadLine() if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } break } // Dot by itself marks end; otherwise cut one dot. if len(line) > 0 && line[0] == '.' { if len(line) == 1 { break } line = line[1:] } v = append(v, line) } return v, err } var colon = []byte(":") // ReadMIMEHeader reads a MIME-style header from r. // The header is a sequence of possibly continued Key: Value lines // ending in a blank line. // The returned map m maps [CanonicalMIMEHeaderKey](key) to a // sequence of values in the same order encountered in the input. 
// // For example, consider this input: // // My-Key: Value 1 // Long-Key: Even // Longer Value // My-Key: Value 2 // // Given that input, ReadMIMEHeader returns the map: // // map[string][]string{ // "My-Key": {"Value 1", "Value 2"}, // "Long-Key": {"Even Longer Value"}, // } func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { return readMIMEHeader(r, math.MaxInt64, math.MaxInt64) } // readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size. // It is called by the mime/multipart package. func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) { // Avoid lots of small slice allocations later by allocating one // large one ahead of time which we'll cut up into smaller // slices. If this isn't big enough later, we allocate small ones. var strs []string hint := r.upcomingHeaderKeys() if hint > 0 { if hint > 1000 { hint = 1000 // set a cap to avoid overallocation } strs = make([]string, hint) } m := make(MIMEHeader, hint) // Account for 400 bytes of overhead for the MIMEHeader, plus 200 bytes per entry. // Benchmarking map creation as of go1.20, a one-entry MIMEHeader is 416 bytes and large // MIMEHeaders average about 200 bytes per entry. maxMemory -= 400 const mapEntryOverhead = 200 // The first line cannot start with a leading space. if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { const errorLimit = 80 // arbitrary limit on how much of the line we'll quote line, err := r.readLineSlice(errorLimit) if err != nil { return m, err } return m, ProtocolError("malformed MIME header initial line: " + string(line)) } for { kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon) if len(kv) == 0 { return m, err } // Key ends at first colon. 
k, v, ok := bytes.Cut(kv, colon) if !ok { return m, ProtocolError("malformed MIME header line: " + string(kv)) } key, ok := canonicalMIMEHeaderKey(k) if !ok { return m, ProtocolError("malformed MIME header line: " + string(kv)) } for _, c := range v { if !validHeaderValueByte(c) { return m, ProtocolError("malformed MIME header line: " + string(kv)) } } // As per RFC 7230 field-name is a token, tokens consist of one or more chars. // We could return a ProtocolError here, but better to be liberal in what we // accept, so if we get an empty key, skip it. if key == "" { continue } maxHeaders-- if maxHeaders < 0 { return nil, errMessageTooLarge } // Skip initial spaces in value. value := string(bytes.TrimLeft(v, " \t")) vv := m[key] if vv == nil { maxMemory -= int64(len(key)) maxMemory -= mapEntryOverhead } maxMemory -= int64(len(value)) if maxMemory < 0 { return m, errMessageTooLarge } if vv == nil && len(strs) > 0 { // More than likely this will be a single-element key. // Most headers aren't multi-valued. // Set the capacity on strs[0] to 1, so any future append // won't extend the slice into the other strings. vv, strs = strs[:1:1], strs[1:] vv[0] = value m[key] = vv } else { m[key] = append(vv, value) } if err != nil { return m, err } } } // noValidation is a no-op validation func for readContinuedLineSlice // that permits any lines. func noValidation(_ []byte) error { return nil } // mustHaveFieldNameColon ensures that, per RFC 7230, the // field-name is on a single line, so the first line must // contain a colon. func mustHaveFieldNameColon(line []byte) error { if bytes.IndexByte(line, ':') < 0 { return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line)) } return nil } var nl = []byte("\n") // upcomingHeaderKeys returns an approximation of the number of keys // that will be in this header. If it gets confused, it returns 0. func (r *Reader) upcomingHeaderKeys() (n int) { // Try to determine the 'hint' size. 
r.R.Peek(1) // force a buffer load if empty s := r.R.Buffered() if s == 0 { return } peek, _ := r.R.Peek(s) for len(peek) > 0 && n < 1000 { var line []byte line, peek, _ = bytes.Cut(peek, nl) if len(line) == 0 || (len(line) == 1 && line[0] == '\r') { // Blank line separating headers from the body. break } if line[0] == ' ' || line[0] == '\t' { // Folded continuation of the previous line. continue } n++ } return n } // CanonicalMIMEHeaderKey returns the canonical format of the // MIME header key s. The canonicalization converts the first // letter and any letter following a hyphen to upper case; // the rest are converted to lowercase. For example, the // canonical key for "accept-encoding" is "Accept-Encoding". // MIME header keys are assumed to be ASCII only. // If s contains a space or invalid header field bytes, it is // returned without modifications. func CanonicalMIMEHeaderKey(s string) string { // Quick check for canonical encoding. upper := true for i := 0; i < len(s); i++ { c := s[i] if !validHeaderFieldByte(c) { return s } if upper && 'a' <= c && c <= 'z' { s, _ = canonicalMIMEHeaderKey([]byte(s)) return s } if !upper && 'A' <= c && c <= 'Z' { s, _ = canonicalMIMEHeaderKey([]byte(s)) return s } upper = c == '-' } return s } const toLower = 'a' - 'A' // validHeaderFieldByte reports whether c is a valid byte in a header // field name. RFC 7230 says: // // header-field = field-name ":" OWS field-value OWS // field-name = token // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / // "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA // token = 1*tchar func validHeaderFieldByte(c byte) bool { // mask is a 128-bit bitmap with 1s for allowed bytes, // so that the byte c can be tested with a shift and an and. // If c >= 128, then 1<>64)) != 0 } // validHeaderValueByte reports whether c is a valid byte in a header // field value. 
RFC 7230 says: // // field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] // field-vchar = VCHAR / obs-text // obs-text = %x80-FF // // RFC 5234 says: // // HTAB = %x09 // SP = %x20 // VCHAR = %x21-7E func validHeaderValueByte(c byte) bool { // mask is a 128-bit bitmap with 1s for allowed bytes, // so that the byte c can be tested with a shift and an and. // If c >= 128, then 1<>64)) == 0 } // canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is // allowed to mutate the provided byte slice before returning the // string. // // For invalid inputs (if a contains spaces or non-token bytes), a // is unchanged and a string copy is returned. // // ok is true if the header key contains only valid characters and spaces. // ReadMIMEHeader accepts header keys containing spaces, but does not // canonicalize them. func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) { // See if a looks like a header key. If not, return it unchanged. noCanon := false for _, c := range a { if validHeaderFieldByte(c) { continue } // Don't canonicalize. if c == ' ' { // We accept invalid headers with a space before the // colon, but must not canonicalize them. // See https://go.dev/issue/34540. noCanon = true continue } return string(a), false } if noCanon { return string(a), true } upper := true for i, c := range a { // Canonicalize: first letter upper case // and upper case after each dash. // (Host, User-Agent, If-Modified-Since). // MIME headers are ASCII only, so no Unicode issues. if upper && 'a' <= c && c <= 'z' { c -= toLower } else if !upper && 'A' <= c && c <= 'Z' { c += toLower } a[i] = c upper = c == '-' // for next time } commonHeaderOnce.Do(initCommonHeader) // The compiler recognizes m[string(byteSlice)] as a special // case, so a copy of a's bytes into a new string does not // happen in this map lookup: if v := commonHeader[string(a)]; v != "" { return v, true } return string(a), true } // commonHeader interns common header strings. 
var commonHeader map[string]string

// commonHeaderOnce guards the one-time construction of commonHeader.
var commonHeaderOnce sync.Once

// initCommonHeader builds commonHeader, mapping each well-known
// canonical header name to itself.
func initCommonHeader() {
	names := [...]string{
		"Accept",
		"Accept-Charset",
		"Accept-Encoding",
		"Accept-Language",
		"Accept-Ranges",
		"Cache-Control",
		"Cc",
		"Connection",
		"Content-Id",
		"Content-Language",
		"Content-Length",
		"Content-Transfer-Encoding",
		"Content-Type",
		"Cookie",
		"Date",
		"Dkim-Signature",
		"Etag",
		"Expires",
		"From",
		"Host",
		"If-Modified-Since",
		"If-None-Match",
		"In-Reply-To",
		"Last-Modified",
		"Location",
		"Message-Id",
		"Mime-Version",
		"Pragma",
		"Received",
		"Return-Path",
		"Server",
		"Set-Cookie",
		"Subject",
		"To",
		"User-Agent",
		"Via",
		"X-Forwarded-For",
		"X-Imforwards",
		"X-Powered-By",
	}

	table := make(map[string]string)
	for i := range names {
		name := names[i]
		table[name] = name
	}
	commonHeader = table
}