1
2
3
4
5
6 package cookiejar
7
8 import (
9 "errors"
10 "fmt"
11 "net"
12 "net/http"
13 "net/http/internal/ascii"
14 "net/url"
15 "sort"
16 "strings"
17 "sync"
18 "time"
19 )
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// PublicSuffixList supplies public-suffix information to a Jar.
//
// Within this file it is consulted in two places: jarKey uses the length of
// the returned suffix to pick the key under which a host's cookies are
// stored, and domainAndType uses it to decide whether a cookie's Domain
// attribute may be honored.
//
// NOTE(review): jarKey is called before the Jar's mutex is acquired, so
// implementations presumably need to be safe for concurrent use — confirm
// against the package documentation.
type PublicSuffixList interface {
	// PublicSuffix returns the public suffix of domain.
	//
	// The callers in this file pass either a canonicalized request host or
	// a lower-cased Domain attribute without leading or trailing dot. An
	// empty result makes domainAndType skip the public-suffix check, and
	// jarKey falls back to the whole host when the result is not preceded
	// by a '.' inside host.
	PublicSuffix(domain string) string

	// String returns a textual description of the list. It is not called
	// anywhere in this file.
	// NOTE(review): presumably identifies the source/version of the list —
	// confirm against the package documentation.
	String() string
}
48
49
// Options holds the settings New uses when creating a Jar.
type Options struct {
	// PublicSuffixList is copied into the new Jar by New. It may be left
	// nil; jarKey then groups cookies by the last two labels of the host,
	// and domainAndType performs no public-suffix check on Domain
	// attributes.
	PublicSuffixList PublicSuffixList
}
59
60
// Jar is an in-memory cookie store. Its Cookies and SetCookies methods have
// the signatures required by http.CookieJar.
type Jar struct {
	// psList is the public suffix list taken from Options; it is nil when
	// New was given nil Options or Options without a list.
	psList PublicSuffixList

	// mu is locked by cookies and setCookies before they touch the fields
	// below.
	mu sync.Mutex

	// entries holds the stored cookies as a two-level map: the outer key is
	// the value jarKey returns for the request host, the inner key is the
	// entry's id().
	entries map[string]map[string]entry

	// nextSeqNum is the sequence number setCookies assigns to the next
	// newly created entry; it is incremented after each assignment.
	nextSeqNum uint64
}
75
76
77
78 func New(o *Options) (*Jar, error) {
79 jar := &Jar{
80 entries: make(map[string]map[string]entry),
81 }
82 if o != nil {
83 jar.psList = o.PublicSuffixList
84 }
85 return jar, nil
86 }
87
88
89
90
91
// entry is the Jar's internal representation of one stored cookie. Entries
// are built by newEntry and completed (Creation, LastAccess, seqNum) by
// setCookies.
type entry struct {
	Name       string    // cookie name, copied from http.Cookie.Name
	Value      string    // cookie value, copied from http.Cookie.Value
	Domain     string    // host or accepted Domain attribute, as returned by domainAndType
	Path       string    // cookie Path if it starts with '/', otherwise the request's default path
	SameSite   string    // textual SameSite attribute derived from http.Cookie.SameSite in newEntry
	Secure     bool      // when set, shouldSend only matches https requests
	HttpOnly   bool      // copied from http.Cookie.HttpOnly; not consulted in this file
	Persistent bool      // true when the cookie carried Max-Age or Expires; only persistent entries are expired by cookies
	HostOnly   bool      // when set, domainMatch requires an exact host match
	Expires    time.Time // expiry instant; endOfTime for non-persistent entries
	Creation   time.Time // time the entry was first stored; preserved when the cookie is overwritten
	LastAccess time.Time // updated whenever the entry is stored or selected for a request

	// seqNum is taken from Jar.nextSeqNum when the entry is first stored.
	// cookies uses it as the final tie-breaker when sorting entries whose
	// path length and Creation time are equal.
	seqNum uint64
}
111
112
113 func (e *entry) id() string {
114 return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
115 }
116
117
118
119
120 func (e *entry) shouldSend(https bool, host, path string) bool {
121 return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
122 }
123
124
125 func (e *entry) domainMatch(host string) bool {
126 if e.Domain == host {
127 return true
128 }
129 return !e.HostOnly && hasDotSuffix(host, e.Domain)
130 }
131
132
133 func (e *entry) pathMatch(requestPath string) bool {
134 if requestPath == e.Path {
135 return true
136 }
137 if strings.HasPrefix(requestPath, e.Path) {
138 if e.Path[len(e.Path)-1] == '/' {
139 return true
140 } else if requestPath[len(e.Path)] == '/' {
141 return true
142 }
143 }
144 return false
145 }
146
147
// hasDotSuffix reports whether s ends in "."+suffix.
func hasDotSuffix(s, suffix string) bool {
	cut := len(s) - len(suffix)
	if cut <= 0 {
		// s has no room for a '.' in front of suffix.
		return false
	}
	return s[cut-1] == '.' && s[cut:] == suffix
}
151
152
153
154
155 func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
156 return j.cookies(u, time.Now())
157 }
158
159
160 func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
161 if u.Scheme != "http" && u.Scheme != "https" {
162 return cookies
163 }
164 host, err := canonicalHost(u.Host)
165 if err != nil {
166 return cookies
167 }
168 key := jarKey(host, j.psList)
169
170 j.mu.Lock()
171 defer j.mu.Unlock()
172
173 submap := j.entries[key]
174 if submap == nil {
175 return cookies
176 }
177
178 https := u.Scheme == "https"
179 path := u.Path
180 if path == "" {
181 path = "/"
182 }
183
184 modified := false
185 var selected []entry
186 for id, e := range submap {
187 if e.Persistent && !e.Expires.After(now) {
188 delete(submap, id)
189 modified = true
190 continue
191 }
192 if !e.shouldSend(https, host, path) {
193 continue
194 }
195 e.LastAccess = now
196 submap[id] = e
197 selected = append(selected, e)
198 modified = true
199 }
200 if modified {
201 if len(submap) == 0 {
202 delete(j.entries, key)
203 } else {
204 j.entries[key] = submap
205 }
206 }
207
208
209
210 sort.Slice(selected, func(i, j int) bool {
211 s := selected
212 if len(s[i].Path) != len(s[j].Path) {
213 return len(s[i].Path) > len(s[j].Path)
214 }
215 if !s[i].Creation.Equal(s[j].Creation) {
216 return s[i].Creation.Before(s[j].Creation)
217 }
218 return s[i].seqNum < s[j].seqNum
219 })
220 for _, e := range selected {
221 cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
222 }
223
224 return cookies
225 }
226
227
228
229
230 func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
231 j.setCookies(u, cookies, time.Now())
232 }
233
234
235 func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
236 if len(cookies) == 0 {
237 return
238 }
239 if u.Scheme != "http" && u.Scheme != "https" {
240 return
241 }
242 host, err := canonicalHost(u.Host)
243 if err != nil {
244 return
245 }
246 key := jarKey(host, j.psList)
247 defPath := defaultPath(u.Path)
248
249 j.mu.Lock()
250 defer j.mu.Unlock()
251
252 submap := j.entries[key]
253
254 modified := false
255 for _, cookie := range cookies {
256 e, remove, err := j.newEntry(cookie, now, defPath, host)
257 if err != nil {
258 continue
259 }
260 id := e.id()
261 if remove {
262 if submap != nil {
263 if _, ok := submap[id]; ok {
264 delete(submap, id)
265 modified = true
266 }
267 }
268 continue
269 }
270 if submap == nil {
271 submap = make(map[string]entry)
272 }
273
274 if old, ok := submap[id]; ok {
275 e.Creation = old.Creation
276 e.seqNum = old.seqNum
277 } else {
278 e.Creation = now
279 e.seqNum = j.nextSeqNum
280 j.nextSeqNum++
281 }
282 e.LastAccess = now
283 submap[id] = e
284 modified = true
285 }
286
287 if modified {
288 if len(submap) == 0 {
289 delete(j.entries, key)
290 } else {
291 j.entries[key] = submap
292 }
293 }
294 }
295
296
297
298 func canonicalHost(host string) (string, error) {
299 var err error
300 if hasPort(host) {
301 host, _, err = net.SplitHostPort(host)
302 if err != nil {
303 return "", err
304 }
305 }
306 if strings.HasSuffix(host, ".") {
307
308 host = host[:len(host)-1]
309 }
310 encoded, err := toASCII(host)
311 if err != nil {
312 return "", err
313 }
314
315 lower, _ := ascii.ToLower(encoded)
316 return lower, nil
317 }
318
319
320
// hasPort reports whether host looks like "host:port". A single colon is
// taken as a port separator; with several colons (an IPv6 literal) a port
// is only recognized in the bracketed form "[...]:port".
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		return true
	}
	return strings.HasPrefix(host, "[") && strings.Contains(host, "]:")
}
331
332
333 func jarKey(host string, psl PublicSuffixList) string {
334 if isIP(host) {
335 return host
336 }
337
338 var i int
339 if psl == nil {
340 i = strings.LastIndex(host, ".")
341 if i <= 0 {
342 return host
343 }
344 } else {
345 suffix := psl.PublicSuffix(host)
346 if suffix == host {
347 return host
348 }
349 i = len(host) - len(suffix)
350 if i <= 0 || host[i-1] != '.' {
351
352
353 return host
354 }
355
356
357
358 }
359 prevDot := strings.LastIndex(host[:i-1], ".")
360 return host[prevDot+1:]
361 }
362
363
// isIP reports whether host must be treated as an IP address rather than a
// domain name.
//
// Anything net.ParseIP accepts is an IP address. In addition, any host
// containing ':' or '%' is reported as an IP address: net.ParseIP rejects
// zoned IPv6 literals such as "fe80::1%eth0", and a valid host name cannot
// contain either character. Classifying such hosts as IPs is the
// conservative choice, because jarKey and domainAndType would otherwise
// apply domain-name rules to them and a crafted zone could let a cookie be
// stored for an unrelated domain.
func isIP(host string) bool {
	if strings.ContainsAny(host, ":%") {
		return true
	}
	return net.ParseIP(host) != nil
}
367
368
369
// defaultPath returns the directory part of a request path, used as the
// Path of cookies that do not carry a usable one: everything before the
// last '/', or "/" when path is empty, does not start with '/', or has no
// further '/'.
func defaultPath(path string) string {
	if !strings.HasPrefix(path, "/") {
		// Empty or malformed path.
		return "/"
	}
	// path starts with '/', so LastIndex cannot be -1.
	if dir := path[:strings.LastIndex(path, "/")]; dir != "" {
		return dir // "/abc/xyz" and "/abc/xyz/" forms
	}
	return "/" // "/abc" form
}
381
382
383
384
385
386
387
388
389
390
391 func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
392 e.Name = c.Name
393
394 if c.Path == "" || c.Path[0] != '/' {
395 e.Path = defPath
396 } else {
397 e.Path = c.Path
398 }
399
400 e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
401 if err != nil {
402 return e, false, err
403 }
404
405
406 if c.MaxAge < 0 {
407 return e, true, nil
408 } else if c.MaxAge > 0 {
409 e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
410 e.Persistent = true
411 } else {
412 if c.Expires.IsZero() {
413 e.Expires = endOfTime
414 e.Persistent = false
415 } else {
416 if !c.Expires.After(now) {
417 return e, true, nil
418 }
419 e.Expires = c.Expires
420 e.Persistent = true
421 }
422 }
423
424 e.Value = c.Value
425 e.Secure = c.Secure
426 e.HttpOnly = c.HttpOnly
427
428 switch c.SameSite {
429 case http.SameSiteDefaultMode:
430 e.SameSite = "SameSite"
431 case http.SameSiteStrictMode:
432 e.SameSite = "SameSite=Strict"
433 case http.SameSiteLaxMode:
434 e.SameSite = "SameSite=Lax"
435 }
436
437 return e, false, nil
438 }
439
// Errors returned by domainAndType when a cookie's Domain attribute cannot
// be accepted. setCookies skips any cookie for which one of them is
// reported.
var (
	// errIllegalDomain: the domain is not the request host or a parent of
	// it, or it is rejected by the public suffix check.
	errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
	// errMalformedDomain: the domain is empty after stripping a leading
	// dot, starts with two dots, ends in a dot, or is not ASCII.
	errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
	// errNoHostname: a Domain attribute was given but the request host is
	// an IP address.
	errNoHostname = errors.New("cookiejar: no host name available (IP only)")
)

// endOfTime is the Expires value newEntry assigns to cookies that carry
// neither Max-Age nor Expires (non-persistent entries).
var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
450
451
452 func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
453 if domain == "" {
454
455
456 return host, true, nil
457 }
458
459 if isIP(host) {
460
461
462
463 return "", false, errNoHostname
464 }
465
466
467
468
469 if domain[0] == '.' {
470 domain = domain[1:]
471 }
472
473 if len(domain) == 0 || domain[0] == '.' {
474
475
476 return "", false, errMalformedDomain
477 }
478
479 domain, isASCII := ascii.ToLower(domain)
480 if !isASCII {
481
482 return "", false, errMalformedDomain
483 }
484
485 if domain[len(domain)-1] == '.' {
486
487
488
489
490
491
492 return "", false, errMalformedDomain
493 }
494
495
496 if j.psList != nil {
497 if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
498 if host == domain {
499
500
501 return host, true, nil
502 }
503 return "", false, errIllegalDomain
504 }
505 }
506
507
508
509 if host != domain && !hasDotSuffix(host, domain) {
510 return "", false, errIllegalDomain
511 }
512
513 return domain, false, nil
514 }
515
View as plain text