1
2
3
4
5 package httptest
6
7 import (
8 "bufio"
9 "io"
10 "net"
11 "net/http"
12 "sync"
13 "testing"
14 )
15
16 type newServerFunc func(http.Handler) *Server
17
18 var newServers = map[string]newServerFunc{
19 "NewServer": NewServer,
20 "NewTLSServer": NewTLSServer,
21
22
23
24 "NewServerManual": func(h http.Handler) *Server {
25 ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
26 ts.Start()
27 return ts
28 },
29 "NewTLSServerManual": func(h http.Handler) *Server {
30 ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
31 ts.StartTLS()
32 return ts
33 },
34 }
35
36 func TestServer(t *testing.T) {
37 for _, name := range []string{"NewServer", "NewServerManual"} {
38 t.Run(name, func(t *testing.T) {
39 newServer := newServers[name]
40 t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
41 t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
42 t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
43 t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
44 t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
45 })
46 }
47 for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
48 t.Run(name, func(t *testing.T) {
49 newServer := newServers[name]
50 t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
51 t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
52 })
53 }
54 }
55
56 func testServer(t *testing.T, newServer newServerFunc) {
57 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58 w.Write([]byte("hello"))
59 }))
60 defer ts.Close()
61 res, err := http.Get(ts.URL)
62 if err != nil {
63 t.Fatal(err)
64 }
65 got, err := io.ReadAll(res.Body)
66 res.Body.Close()
67 if err != nil {
68 t.Fatal(err)
69 }
70 if string(got) != "hello" {
71 t.Errorf("got %q, want hello", string(got))
72 }
73 }
74
75
76 func testGetAfterClose(t *testing.T, newServer newServerFunc) {
77 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78 w.Write([]byte("hello"))
79 }))
80
81 res, err := http.Get(ts.URL)
82 if err != nil {
83 t.Fatal(err)
84 }
85 got, err := io.ReadAll(res.Body)
86 if err != nil {
87 t.Fatal(err)
88 }
89 if string(got) != "hello" {
90 t.Fatalf("got %q, want hello", string(got))
91 }
92
93 ts.Close()
94
95 res, err = http.Get(ts.URL)
96 if err == nil {
97 body, _ := io.ReadAll(res.Body)
98 t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
99 }
100 }
101
102 func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
103 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
104 w.Write([]byte("hello"))
105 }))
106 dial := func() net.Conn {
107 c, err := net.Dial("tcp", ts.Listener.Addr().String())
108 if err != nil {
109 t.Fatal(err)
110 }
111 return c
112 }
113
114
115 cnew := dial()
116 defer cnew.Close()
117
118
119 cidle := dial()
120 defer cidle.Close()
121 cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
122 _, err := http.ReadResponse(bufio.NewReader(cidle), nil)
123 if err != nil {
124 t.Fatal(err)
125 }
126
127 ts.Close()
128 }
129
130
131 func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
132 var s *Server
133 s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
134 s.CloseClientConnections()
135 }))
136 defer s.Close()
137 res, err := http.Get(s.URL)
138 if err == nil {
139 res.Body.Close()
140 t.Fatalf("Unexpected response: %#v", res)
141 }
142 }
143
144
145
146 func testServerClient(t *testing.T, newTLSServer newServerFunc) {
147 ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
148 w.Write([]byte("hello"))
149 }))
150 defer ts.Close()
151 client := ts.Client()
152 res, err := client.Get(ts.URL)
153 if err != nil {
154 t.Fatal(err)
155 }
156 got, err := io.ReadAll(res.Body)
157 res.Body.Close()
158 if err != nil {
159 t.Fatal(err)
160 }
161 if string(got) != "hello" {
162 t.Errorf("got %q, want hello", string(got))
163 }
164 }
165
166
167
168 func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
169 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
170 }))
171 defer ts.Close()
172 client := ts.Client()
173 if _, ok := client.Transport.(*http.Transport); !ok {
174 t.Errorf("got %T, want *http.Transport", client.Transport)
175 }
176 }
177
178
179
180 func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
181 ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
182 }))
183 defer ts.Close()
184 client := ts.Client()
185 if _, ok := client.Transport.(*http.Transport); !ok {
186 t.Errorf("got %T, want *http.Transport", client.Transport)
187 }
188 }
189
// onlyCloseListener embeds a net.Listener and replaces its Close method with
// a no-op. TestServerZeroValueClose uses it with the embedded Listener left
// nil.
type onlyCloseListener struct {
	net.Listener
}

// Close does nothing and always returns nil; it never touches the embedded
// Listener.
func (onlyCloseListener) Close() error {
	return nil
}
195
196
197
198 func TestServerZeroValueClose(t *testing.T) {
199 ts := &Server{
200 Listener: onlyCloseListener{},
201 Config: &http.Server{},
202 }
203
204 ts.Close()
205 }
206
207
208
209 func TestCloseHijackedConnection(t *testing.T) {
210 hijacked := make(chan net.Conn)
211 ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
212 defer close(hijacked)
213 hj, ok := w.(http.Hijacker)
214 if !ok {
215 t.Fatal("failed to hijack")
216 }
217 c, _, err := hj.Hijack()
218 if err != nil {
219 t.Fatal(err)
220 }
221 hijacked <- c
222 }))
223
224 var wg sync.WaitGroup
225 wg.Add(1)
226 go func() {
227 defer wg.Done()
228 req, err := http.NewRequest("GET", ts.URL, nil)
229 if err != nil {
230 t.Log(err)
231 }
232
233 var c http.Client
234 resp, err := c.Do(req)
235 if err != nil {
236 t.Log(err)
237 return
238 }
239 resp.Body.Close()
240 }()
241
242 wg.Add(1)
243 conn := <-hijacked
244 go func(conn net.Conn) {
245 defer wg.Done()
246
247
248 conn.Close()
249 ts.Config.ConnState(conn, http.StateClosed)
250 }(conn)
251
252 wg.Add(1)
253 go func() {
254 defer wg.Done()
255 ts.Close()
256 }()
257 wg.Wait()
258 }
259
260 func TestTLSServerWithHTTP2(t *testing.T) {
261 modes := []struct {
262 name string
263 wantProto string
264 }{
265 {"http1", "HTTP/1.1"},
266 {"http2", "HTTP/2.0"},
267 }
268
269 for _, tt := range modes {
270 t.Run(tt.name, func(t *testing.T) {
271 cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
272 w.Header().Set("X-Proto", r.Proto)
273 }))
274
275 switch tt.name {
276 case "http2":
277 cst.EnableHTTP2 = true
278 cst.StartTLS()
279 default:
280 cst.Start()
281 }
282
283 defer cst.Close()
284
285 res, err := cst.Client().Get(cst.URL)
286 if err != nil {
287 t.Fatalf("Failed to make request: %v", err)
288 }
289 if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
290 t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w)
291 }
292 })
293 }
294 }
295
View as plain text