1
2
3
4
5
6 package cookiejar
7
8 import (
9 "errors"
10 "fmt"
11 "net"
12 "net/http"
13 "net/http/internal/ascii"
14 "net/url"
15 "sort"
16 "strings"
17 "sync"
18 "time"
19 )
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// PublicSuffixList supplies the public suffix of a domain name. A Jar
// consults it when choosing the key under which a host's cookies are
// stored and when deciding whether a cookie's Domain attribute is
// acceptable.
//
// Implementations may be called from several goroutines at once: the jar
// invokes PublicSuffix before taking its own lock.
type PublicSuffixList interface {
	// PublicSuffix returns the public suffix of domain.
	//
	// Within this package it is called with a lower-cased host or Domain
	// attribute. The jar expects the result to be either domain itself
	// or a proper suffix of domain preceded by a dot; any other result
	// makes the jar fall back to storing cookies under the full host.
	PublicSuffix(domain string) string

	// String returns a human-readable description of the list.
	// NOTE(review): nothing in this file calls String; presumably it
	// identifies the source or version of the list — confirm with
	// implementers.
	String() string
}
48
49
50 type Options struct {
51
52
53
54
55
56
57 PublicSuffixList PublicSuffixList
58 }
59
60
61 type Jar struct {
62 psList PublicSuffixList
63
64
65 mu sync.Mutex
66
67
68
69 entries map[string]map[string]entry
70
71
72
73 nextSeqNum uint64
74 }
75
76
77
78 func New(o *Options) (*Jar, error) {
79 jar := &Jar{
80 entries: make(map[string]map[string]entry),
81 }
82 if o != nil {
83 jar.psList = o.PublicSuffixList
84 }
85 return jar, nil
86 }
87
88
89
90
91
// entry is the jar's internal representation of one stored cookie.
type entry struct {
	// Name and Value are copied from the cookie. Domain is the domain
	// the cookie applies to and Path the path it applies to (the
	// cookie's own path or the default path of the request URL).
	// SameSite is empty or one of "SameSite", "SameSite=Strict" and
	// "SameSite=Lax".
	Name, Value, Domain, Path, SameSite string

	// Secure entries are only sent over https. HttpOnly is copied from
	// the cookie. Persistent is set when the cookie carried a positive
	// MaxAge or a non-zero Expires; only persistent entries are ever
	// expired. HostOnly entries match their Domain exactly and never a
	// subdomain of it.
	Secure, HttpOnly, Persistent, HostOnly bool

	// Expires is the expiry instant (endOfTime for non-persistent
	// entries). Creation is when the entry was first stored and is kept
	// when the cookie is overwritten. LastAccess is refreshed whenever
	// the entry is stored or selected for a request.
	Expires, Creation, LastAccess time.Time

	// seqNum is the final tie-breaker when cookies are sorted for a
	// request, so that entries with equal path length and equal
	// creation time still come out in a deterministic order.
	seqNum uint64
}
111
112
113 func (e *entry) id() string {
114 return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
115 }
116
117
118
119
120 func (e *entry) shouldSend(https bool, host, path string) bool {
121 return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
122 }
123
124
125
126
127 func (e *entry) domainMatch(host string) bool {
128 if e.Domain == host {
129 return true
130 }
131 return !e.HostOnly && hasDotSuffix(host, e.Domain)
132 }
133
134
135 func (e *entry) pathMatch(requestPath string) bool {
136 if requestPath == e.Path {
137 return true
138 }
139 if strings.HasPrefix(requestPath, e.Path) {
140 if e.Path[len(e.Path)-1] == '/' {
141 return true
142 } else if requestPath[len(e.Path)] == '/' {
143 return true
144 }
145 }
146 return false
147 }
148
149
// hasDotSuffix reports whether s ends in "." followed by suffix.
func hasDotSuffix(s, suffix string) bool {
	// cut is where suffix would have to start inside s; it must leave
	// room for at least the separating dot in front of it.
	cut := len(s) - len(suffix)
	if cut <= 0 {
		return false
	}
	return s[cut-1] == '.' && s[cut:] == suffix
}
153
154
155
156
157 func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
158 return j.cookies(u, time.Now())
159 }
160
161
162 func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
163 if u.Scheme != "http" && u.Scheme != "https" {
164 return cookies
165 }
166 host, err := canonicalHost(u.Host)
167 if err != nil {
168 return cookies
169 }
170 key := jarKey(host, j.psList)
171
172 j.mu.Lock()
173 defer j.mu.Unlock()
174
175 submap := j.entries[key]
176 if submap == nil {
177 return cookies
178 }
179
180 https := u.Scheme == "https"
181 path := u.Path
182 if path == "" {
183 path = "/"
184 }
185
186 modified := false
187 var selected []entry
188 for id, e := range submap {
189 if e.Persistent && !e.Expires.After(now) {
190 delete(submap, id)
191 modified = true
192 continue
193 }
194 if !e.shouldSend(https, host, path) {
195 continue
196 }
197 e.LastAccess = now
198 submap[id] = e
199 selected = append(selected, e)
200 modified = true
201 }
202 if modified {
203 if len(submap) == 0 {
204 delete(j.entries, key)
205 } else {
206 j.entries[key] = submap
207 }
208 }
209
210
211
212 sort.Slice(selected, func(i, j int) bool {
213 s := selected
214 if len(s[i].Path) != len(s[j].Path) {
215 return len(s[i].Path) > len(s[j].Path)
216 }
217 if ret := s[i].Creation.Compare(s[j].Creation); ret != 0 {
218 return ret < 0
219 }
220 return s[i].seqNum < s[j].seqNum
221 })
222 for _, e := range selected {
223 cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
224 }
225
226 return cookies
227 }
228
229
230
231
232 func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
233 j.setCookies(u, cookies, time.Now())
234 }
235
236
237 func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
238 if len(cookies) == 0 {
239 return
240 }
241 if u.Scheme != "http" && u.Scheme != "https" {
242 return
243 }
244 host, err := canonicalHost(u.Host)
245 if err != nil {
246 return
247 }
248 key := jarKey(host, j.psList)
249 defPath := defaultPath(u.Path)
250
251 j.mu.Lock()
252 defer j.mu.Unlock()
253
254 submap := j.entries[key]
255
256 modified := false
257 for _, cookie := range cookies {
258 e, remove, err := j.newEntry(cookie, now, defPath, host)
259 if err != nil {
260 continue
261 }
262 id := e.id()
263 if remove {
264 if submap != nil {
265 if _, ok := submap[id]; ok {
266 delete(submap, id)
267 modified = true
268 }
269 }
270 continue
271 }
272 if submap == nil {
273 submap = make(map[string]entry)
274 }
275
276 if old, ok := submap[id]; ok {
277 e.Creation = old.Creation
278 e.seqNum = old.seqNum
279 } else {
280 e.Creation = now
281 e.seqNum = j.nextSeqNum
282 j.nextSeqNum++
283 }
284 e.LastAccess = now
285 submap[id] = e
286 modified = true
287 }
288
289 if modified {
290 if len(submap) == 0 {
291 delete(j.entries, key)
292 } else {
293 j.entries[key] = submap
294 }
295 }
296 }
297
298
299
300 func canonicalHost(host string) (string, error) {
301 var err error
302 if hasPort(host) {
303 host, _, err = net.SplitHostPort(host)
304 if err != nil {
305 return "", err
306 }
307 }
308
309 host = strings.TrimSuffix(host, ".")
310 encoded, err := toASCII(host)
311 if err != nil {
312 return "", err
313 }
314
315 lower, _ := ascii.ToLower(encoded)
316 return lower, nil
317 }
318
319
320
// hasPort reports whether host contains a port number. host may be a host
// name, an IPv4 address or an IPv6 address.
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		// "name:port" or "1.2.3.4:port".
		return true
	default:
		// Several colons: only a bracketed IPv6 literal followed by
		// ":port" carries a port.
		return host[0] == '[' && strings.Contains(host, "]:")
	}
}
331
332
333 func jarKey(host string, psl PublicSuffixList) string {
334 if isIP(host) {
335 return host
336 }
337
338 var i int
339 if psl == nil {
340 i = strings.LastIndex(host, ".")
341 if i <= 0 {
342 return host
343 }
344 } else {
345 suffix := psl.PublicSuffix(host)
346 if suffix == host {
347 return host
348 }
349 i = len(host) - len(suffix)
350 if i <= 0 || host[i-1] != '.' {
351
352
353 return host
354 }
355
356
357
358 }
359 prevDot := strings.LastIndex(host[:i-1], ".")
360 return host[prevDot+1:]
361 }
362
363
// isIP reports whether host should be treated as an IP address rather
// than as a host name.
//
// Anything containing ':' or '%' is treated as an IP. Such a string is
// probably an IPv6 literal, possibly with a zone (which net.ParseIP
// rejects) or still wrapped in brackets. A valid host name can contain
// neither character, so classifying these as IPs is the conservative
// choice: it keeps a malformed or zoned IPv6 address from being handled
// as a host name, with the domain-cookie and public-suffix logic that
// would follow.
func isIP(host string) bool {
	if strings.ContainsAny(host, ":%") {
		return true
	}
	return net.ParseIP(host) != nil
}
367
368
369
// defaultPath returns the directory part of a URL's path: everything
// before the last '/'. It returns "/" when path is empty, does not start
// with '/', or contains no '/' other than the leading one.
func defaultPath(path string) string {
	if !strings.HasPrefix(path, "/") {
		return "/"
	}
	if i := strings.LastIndex(path, "/"); i > 0 {
		return path[:i]
	}
	return "/"
}
381
382
383
384
385
386
387
388
389
390
391 func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
392 e.Name = c.Name
393
394 if c.Path == "" || c.Path[0] != '/' {
395 e.Path = defPath
396 } else {
397 e.Path = c.Path
398 }
399
400 e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
401 if err != nil {
402 return e, false, err
403 }
404
405
406 if c.MaxAge < 0 {
407 return e, true, nil
408 } else if c.MaxAge > 0 {
409 e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
410 e.Persistent = true
411 } else {
412 if c.Expires.IsZero() {
413 e.Expires = endOfTime
414 e.Persistent = false
415 } else {
416 if !c.Expires.After(now) {
417 return e, true, nil
418 }
419 e.Expires = c.Expires
420 e.Persistent = true
421 }
422 }
423
424 e.Value = c.Value
425 e.Secure = c.Secure
426 e.HttpOnly = c.HttpOnly
427
428 switch c.SameSite {
429 case http.SameSiteDefaultMode:
430 e.SameSite = "SameSite"
431 case http.SameSiteStrictMode:
432 e.SameSite = "SameSite=Strict"
433 case http.SameSiteLaxMode:
434 e.SameSite = "SameSite=Lax"
435 }
436
437 return e, false, nil
438 }
439
// Errors reported while validating a cookie's Domain attribute.
var (
	// errIllegalDomain: the Domain attribute is well-formed but not one
	// the requesting host is allowed to set.
	errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")

	// errMalformedDomain: the Domain attribute is syntactically
	// unacceptable (empty after the leading dot, doubled leading dot,
	// non-ASCII, or ending in a dot).
	errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")

	// errNoHostname is not referenced in the visible part of this file.
	// NOTE(review): confirm whether another file in the package uses it.
	errNoHostname = errors.New("cookiejar: no host name available (IP only)")
)
445
446
447
448
// endOfTime is the Expires value given to entries that never expire: the
// last second of the year 9999, in UTC.
var endOfTime = time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)
450
451
452 func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
453 if domain == "" {
454
455
456 return host, true, nil
457 }
458
459 if isIP(host) {
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476 if host != domain {
477 return "", false, errIllegalDomain
478 }
479
480
481
482
483
484
485
486
487
488
489 return host, true, nil
490 }
491
492
493
494
495 if domain[0] == '.' {
496 domain = domain[1:]
497 }
498
499 if len(domain) == 0 || domain[0] == '.' {
500
501
502 return "", false, errMalformedDomain
503 }
504
505 domain, isASCII := ascii.ToLower(domain)
506 if !isASCII {
507
508 return "", false, errMalformedDomain
509 }
510
511 if domain[len(domain)-1] == '.' {
512
513
514
515
516
517
518 return "", false, errMalformedDomain
519 }
520
521
522 if j.psList != nil {
523 if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
524 if host == domain {
525
526
527 return host, true, nil
528 }
529 return "", false, errIllegalDomain
530 }
531 }
532
533
534
535 if host != domain && !hasDotSuffix(host, domain) {
536 return "", false, errIllegalDomain
537 }
538
539 return domain, false, nil
540 }
541
View as plain text