1
2
3
4
5 package httptest
6
7 import (
8 "fmt"
9 "io"
10 "net/http"
11 "testing"
12 )
13
14 func TestRecorder(t *testing.T) {
15 type checkFunc func(*ResponseRecorder) error
16 check := func(fns ...checkFunc) []checkFunc { return fns }
17
18 hasStatus := func(wantCode int) checkFunc {
19 return func(rec *ResponseRecorder) error {
20 if rec.Code != wantCode {
21 return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
22 }
23 return nil
24 }
25 }
26 hasResultStatus := func(want string) checkFunc {
27 return func(rec *ResponseRecorder) error {
28 if rec.Result().Status != want {
29 return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
30 }
31 return nil
32 }
33 }
34 hasResultStatusCode := func(wantCode int) checkFunc {
35 return func(rec *ResponseRecorder) error {
36 if rec.Result().StatusCode != wantCode {
37 return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
38 }
39 return nil
40 }
41 }
42 hasResultContents := func(want string) checkFunc {
43 return func(rec *ResponseRecorder) error {
44 contentBytes, err := io.ReadAll(rec.Result().Body)
45 if err != nil {
46 return err
47 }
48 contents := string(contentBytes)
49 if contents != want {
50 return fmt.Errorf("Result().Body = %s; want %s", contents, want)
51 }
52 return nil
53 }
54 }
55 hasContents := func(want string) checkFunc {
56 return func(rec *ResponseRecorder) error {
57 if rec.Body.String() != want {
58 return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
59 }
60 return nil
61 }
62 }
63 hasFlush := func(want bool) checkFunc {
64 return func(rec *ResponseRecorder) error {
65 if rec.Flushed != want {
66 return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
67 }
68 return nil
69 }
70 }
71 hasOldHeader := func(key, want string) checkFunc {
72 return func(rec *ResponseRecorder) error {
73 if got := rec.HeaderMap.Get(key); got != want {
74 return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
75 }
76 return nil
77 }
78 }
79 hasHeader := func(key, want string) checkFunc {
80 return func(rec *ResponseRecorder) error {
81 if got := rec.Result().Header.Get(key); got != want {
82 return fmt.Errorf("final header %s = %q; want %q", key, got, want)
83 }
84 return nil
85 }
86 }
87 hasNotHeaders := func(keys ...string) checkFunc {
88 return func(rec *ResponseRecorder) error {
89 for _, k := range keys {
90 v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
91 if ok {
92 return fmt.Errorf("unexpected header %s with value %q", k, v)
93 }
94 }
95 return nil
96 }
97 }
98 hasTrailer := func(key, want string) checkFunc {
99 return func(rec *ResponseRecorder) error {
100 if got := rec.Result().Trailer.Get(key); got != want {
101 return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
102 }
103 return nil
104 }
105 }
106 hasNotTrailers := func(keys ...string) checkFunc {
107 return func(rec *ResponseRecorder) error {
108 trailers := rec.Result().Trailer
109 for _, k := range keys {
110 _, ok := trailers[http.CanonicalHeaderKey(k)]
111 if ok {
112 return fmt.Errorf("unexpected trailer %s", k)
113 }
114 }
115 return nil
116 }
117 }
118 hasContentLength := func(length int64) checkFunc {
119 return func(rec *ResponseRecorder) error {
120 if got := rec.Result().ContentLength; got != length {
121 return fmt.Errorf("ContentLength = %d; want %d", got, length)
122 }
123 return nil
124 }
125 }
126
127 for _, tt := range [...]struct {
128 name string
129 h func(w http.ResponseWriter, r *http.Request)
130 checks []checkFunc
131 }{
132 {
133 "200 default",
134 func(w http.ResponseWriter, r *http.Request) {},
135 check(hasStatus(200), hasContents("")),
136 },
137 {
138 "first code only",
139 func(w http.ResponseWriter, r *http.Request) {
140 w.WriteHeader(201)
141 w.WriteHeader(202)
142 w.Write([]byte("hi"))
143 },
144 check(hasStatus(201), hasContents("hi")),
145 },
146 {
147 "write sends 200",
148 func(w http.ResponseWriter, r *http.Request) {
149 w.Write([]byte("hi first"))
150 w.WriteHeader(201)
151 w.WriteHeader(202)
152 },
153 check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
154 },
155 {
156 "write string",
157 func(w http.ResponseWriter, r *http.Request) {
158 io.WriteString(w, "hi first")
159 },
160 check(
161 hasStatus(200),
162 hasContents("hi first"),
163 hasFlush(false),
164 hasHeader("Content-Type", "text/plain; charset=utf-8"),
165 ),
166 },
167 {
168 "flush",
169 func(w http.ResponseWriter, r *http.Request) {
170 w.(http.Flusher).Flush()
171 w.WriteHeader(201)
172 },
173 check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
174 },
175 {
176 "Content-Type detection",
177 func(w http.ResponseWriter, r *http.Request) {
178 io.WriteString(w, "<html>")
179 },
180 check(hasHeader("Content-Type", "text/html; charset=utf-8")),
181 },
182 {
183 "no Content-Type detection with Transfer-Encoding",
184 func(w http.ResponseWriter, r *http.Request) {
185 w.Header().Set("Transfer-Encoding", "some encoding")
186 io.WriteString(w, "<html>")
187 },
188 check(hasHeader("Content-Type", "")),
189 },
190 {
191 "no Content-Type detection if set explicitly",
192 func(w http.ResponseWriter, r *http.Request) {
193 w.Header().Set("Content-Type", "some/type")
194 io.WriteString(w, "<html>")
195 },
196 check(hasHeader("Content-Type", "some/type")),
197 },
198 {
199 "Content-Type detection doesn't crash if HeaderMap is nil",
200 func(w http.ResponseWriter, r *http.Request) {
201
202
203
204 w.(*ResponseRecorder).HeaderMap = nil
205 io.WriteString(w, "<html>")
206 },
207 check(hasHeader("Content-Type", "text/html; charset=utf-8")),
208 },
209 {
210 "Header is not changed after write",
211 func(w http.ResponseWriter, r *http.Request) {
212 hdr := w.Header()
213 hdr.Set("Key", "correct")
214 w.WriteHeader(200)
215 hdr.Set("Key", "incorrect")
216 },
217 check(hasHeader("Key", "correct")),
218 },
219 {
220 "Trailer headers are correctly recorded",
221 func(w http.ResponseWriter, r *http.Request) {
222 w.Header().Set("Non-Trailer", "correct")
223 w.Header().Set("Trailer", "Trailer-A")
224 w.Header().Add("Trailer", "Trailer-B")
225 w.Header().Add("Trailer", "Trailer-C")
226 io.WriteString(w, "<html>")
227 w.Header().Set("Non-Trailer", "incorrect")
228 w.Header().Set("Trailer-A", "valuea")
229 w.Header().Set("Trailer-C", "valuec")
230 w.Header().Set("Trailer-NotDeclared", "should be omitted")
231 w.Header().Set("Trailer:Trailer-D", "with prefix")
232 },
233 check(
234 hasStatus(200),
235 hasHeader("Content-Type", "text/html; charset=utf-8"),
236 hasHeader("Non-Trailer", "correct"),
237 hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
238 hasTrailer("Trailer-A", "valuea"),
239 hasTrailer("Trailer-C", "valuec"),
240 hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
241 hasTrailer("Trailer-D", "with prefix"),
242 ),
243 },
244 {
245 "Header set without any write",
246 func(w http.ResponseWriter, r *http.Request) {
247 w.Header().Set("X-Foo", "1")
248
249
250
251
252
253 w.(*ResponseRecorder).Code = 0
254 },
255 check(
256 hasOldHeader("X-Foo", "1"),
257 hasStatus(0),
258 hasHeader("X-Foo", "1"),
259 hasResultStatus("200 OK"),
260 hasResultStatusCode(200),
261 ),
262 },
263 {
264 "HeaderMap vs FinalHeaders",
265 func(w http.ResponseWriter, r *http.Request) {
266 h := w.Header()
267 h.Set("X-Foo", "1")
268 w.Write([]byte("hi"))
269 h.Set("X-Foo", "2")
270 h.Set("X-Bar", "2")
271 },
272 check(
273 hasOldHeader("X-Foo", "2"),
274 hasOldHeader("X-Bar", "2"),
275 hasHeader("X-Foo", "1"),
276 hasNotHeaders("X-Bar"),
277 ),
278 },
279 {
280 "setting Content-Length header",
281 func(w http.ResponseWriter, r *http.Request) {
282 body := "Some body"
283 contentLength := fmt.Sprintf("%d", len(body))
284 w.Header().Set("Content-Length", contentLength)
285 io.WriteString(w, body)
286 },
287 check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
288 },
289 {
290 "nil ResponseRecorder.Body",
291 func(w http.ResponseWriter, r *http.Request) {
292 w.(*ResponseRecorder).Body = nil
293 io.WriteString(w, "hi")
294 },
295 check(hasResultContents("")),
296
297 },
298 } {
299 t.Run(tt.name, func(t *testing.T) {
300 r, _ := http.NewRequest("GET", "http://foo.com/", nil)
301 h := http.HandlerFunc(tt.h)
302 rec := NewRecorder()
303 h.ServeHTTP(rec, r)
304 for _, check := range tt.checks {
305 if err := check(rec); err != nil {
306 t.Error(err)
307 }
308 }
309 })
310 }
311 }
312
313
314 func TestParseContentLength(t *testing.T) {
315 tests := []struct {
316 cl string
317 want int64
318 }{
319 {
320 cl: "3",
321 want: 3,
322 },
323 {
324 cl: "+3",
325 want: -1,
326 },
327 {
328 cl: "-3",
329 want: -1,
330 },
331 {
332
333 cl: "9223372036854775807",
334 want: 9223372036854775807,
335 },
336 {
337 cl: "9223372036854775808",
338 want: -1,
339 },
340 }
341
342 for _, tt := range tests {
343 if got := parseContentLength(tt.cl); got != tt.want {
344 t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
345 }
346 }
347 }
348
349
350
351 func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
352 badCodes := []int{
353 -100, 0, 99, 1000, 20000,
354 }
355 for _, badCode := range badCodes {
356 badCode := badCode
357 t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
358 defer func() {
359 if r := recover(); r == nil {
360 t.Fatal("Expected a panic")
361 }
362 }()
363
364 handler := func(rw http.ResponseWriter, _ *http.Request) {
365 rw.WriteHeader(badCode)
366 }
367 r, _ := http.NewRequest("GET", "http://example.org/", nil)
368 rw := NewRecorder()
369 handler(rw, r)
370 })
371 }
372 }
373
View as plain text