1
2
3
4
5 package httptest
6
7 import (
8 "bytes"
9 "fmt"
10 "io"
11 "net/http"
12 "net/textproto"
13 "strconv"
14 "strings"
15
16 "golang.org/x/net/http/httpguts"
17 )
18
19
20
// ResponseRecorder stores what is written to it through its Header, Write,
// WriteString, WriteHeader and Flush methods so that the outcome can be
// examined afterwards, either through the exported fields or through Result.
type ResponseRecorder struct {
	// Code is the status code stored by WriteHeader.
	Code int

	// HeaderMap is the header map returned by Header. It is created on
	// demand when nil.
	HeaderMap http.Header

	// Body accumulates the bytes passed to Write and WriteString.
	// When it is nil, written data is not retained.
	Body *bytes.Buffer

	// Flushed is set to true by Flush.
	Flushed bool

	result      *http.Response // memoized return value of Result
	snapHeader  http.Header    // clone of HeaderMap taken when the header is written
	wroteHeader bool           // true once WriteHeader has taken effect
}
49
50
51 func NewRecorder() *ResponseRecorder {
52 return &ResponseRecorder{
53 HeaderMap: make(http.Header),
54 Body: new(bytes.Buffer),
55 Code: 200,
56 }
57 }
58
59
60
// DefaultRemoteAddr is the placeholder remote address "1.2.3.4".
// NOTE(review): nothing in this part of the file reads it; presumably it is
// used as a fallback RemoteAddr elsewhere in the package — confirm.
const DefaultRemoteAddr = "1.2.3.4"
62
63
64
65
66
67 func (rw *ResponseRecorder) Header() http.Header {
68 m := rw.HeaderMap
69 if m == nil {
70 m = make(http.Header)
71 rw.HeaderMap = m
72 }
73 return m
74 }
75
76
77
78
79
80
81
82
83 func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
84 if rw.wroteHeader {
85 return
86 }
87 if len(str) > 512 {
88 str = str[:512]
89 }
90
91 m := rw.Header()
92
93 _, hasType := m["Content-Type"]
94 hasTE := m.Get("Transfer-Encoding") != ""
95 if !hasType && !hasTE {
96 if b == nil {
97 b = []byte(str)
98 }
99 m.Set("Content-Type", http.DetectContentType(b))
100 }
101
102 rw.WriteHeader(200)
103 }
104
105
106
107 func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
108 rw.writeHeader(buf, "")
109 if rw.Body != nil {
110 rw.Body.Write(buf)
111 }
112 return len(buf), nil
113 }
114
115
116
117 func (rw *ResponseRecorder) WriteString(str string) (int, error) {
118 rw.writeHeader(nil, str)
119 if rw.Body != nil {
120 rw.Body.WriteString(str)
121 }
122 return len(str), nil
123 }
124
// checkWriteHeaderCode panics with "invalid WriteHeader code N" unless code
// lies in the inclusive range 100..999, i.e. unless it is a three-digit
// number.
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
141
142
143 func (rw *ResponseRecorder) WriteHeader(code int) {
144 if rw.wroteHeader {
145 return
146 }
147
148 checkWriteHeaderCode(code)
149 rw.Code = code
150 rw.wroteHeader = true
151 if rw.HeaderMap == nil {
152 rw.HeaderMap = make(http.Header)
153 }
154 rw.snapHeader = rw.HeaderMap.Clone()
155 }
156
157
158
159 func (rw *ResponseRecorder) Flush() {
160 if !rw.wroteHeader {
161 rw.WriteHeader(200)
162 }
163 rw.Flushed = true
164 }
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181 func (rw *ResponseRecorder) Result() *http.Response {
182 if rw.result != nil {
183 return rw.result
184 }
185 if rw.snapHeader == nil {
186 rw.snapHeader = rw.HeaderMap.Clone()
187 }
188 res := &http.Response{
189 Proto: "HTTP/1.1",
190 ProtoMajor: 1,
191 ProtoMinor: 1,
192 StatusCode: rw.Code,
193 Header: rw.snapHeader,
194 }
195 rw.result = res
196 if res.StatusCode == 0 {
197 res.StatusCode = 200
198 }
199 res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
200 if rw.Body != nil {
201 res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
202 } else {
203 res.Body = http.NoBody
204 }
205 res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
206
207 if trailers, ok := rw.snapHeader["Trailer"]; ok {
208 res.Trailer = make(http.Header, len(trailers))
209 for _, k := range trailers {
210 k = http.CanonicalHeaderKey(k)
211 if !httpguts.ValidTrailerHeader(k) {
212
213 continue
214 }
215 vv, ok := rw.HeaderMap[k]
216 if !ok {
217 continue
218 }
219 vv2 := make([]string, len(vv))
220 copy(vv2, vv)
221 res.Trailer[k] = vv2
222 }
223 }
224 for k, vv := range rw.HeaderMap {
225 if !strings.HasPrefix(k, http.TrailerPrefix) {
226 continue
227 }
228 if res.Trailer == nil {
229 res.Trailer = make(http.Header)
230 }
231 for _, v := range vv {
232 res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
233 }
234 }
235 return res
236 }
237
238
239
240
241
242
// parseContentLength converts a Content-Length header value into an int64.
// Surrounding whitespace is trimmed first. It returns -1 when the value is
// empty or is not an unsigned base-10 number that fits in 63 bits.
func parseContentLength(cl string) int64 {
	trimmed := textproto.TrimString(cl)
	if trimmed == "" {
		return -1
	}
	// A bit size of 63 guarantees the parsed value fits in a non-negative int64.
	if n, err := strconv.ParseUint(trimmed, 10, 63); err == nil {
		return int64(n)
	}
	return -1
}
254
View as plain text