Source file
src/net/http/transport_internal_test.go
1
2
3
4
5
6
7 package http
8
9 import (
10 "bytes"
11 "crypto/tls"
12 "errors"
13 "io"
14 "net"
15 "net/http/internal/testcert"
16 "strings"
17 "testing"
18 )
19
20
21 func TestTransportPersistConnReadLoopEOF(t *testing.T) {
22 ln := newLocalListener(t)
23 defer ln.Close()
24
25 connc := make(chan net.Conn, 1)
26 go func() {
27 defer close(connc)
28 c, err := ln.Accept()
29 if err != nil {
30 t.Error(err)
31 return
32 }
33 connc <- c
34 }()
35
36 tr := new(Transport)
37 req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
38 req = req.WithT(t)
39 treq := &transportRequest{Request: req}
40 cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
41 pc, err := tr.getConn(treq, cm)
42 if err != nil {
43 t.Fatal(err)
44 }
45 defer pc.close(errors.New("test over"))
46
47 conn := <-connc
48 if conn == nil {
49
50 return
51 }
52 conn.Close()
53
54 _, err = pc.roundTrip(treq)
55 if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
56 t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err)
57 }
58
59 <-pc.closech
60 err = pc.closed
61 if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
62 t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
63 }
64 }
65
66 func isNothingWrittenError(err error) bool {
67 _, ok := err.(nothingWrittenError)
68 return ok
69 }
70
71 func isTransportReadFromServerError(err error) bool {
72 _, ok := err.(transportReadFromServerError)
73 return ok
74 }
75
// newLocalListener returns a TCP listener bound to an ephemeral port on a
// loopback address. It tries IPv4 (127.0.0.1) first and falls back to IPv6
// ([::1]). If neither can be bound, it fails the test via t.Fatalf and reports
// both errors, so the IPv4 cause is not hidden by the fallback attempt.
func newLocalListener(t *testing.T) net.Listener {
	t.Helper()
	ln, err4 := net.Listen("tcp", "127.0.0.1:0")
	if err4 == nil {
		return ln
	}
	ln, err6 := net.Listen("tcp6", "[::1]:0")
	if err6 != nil {
		t.Fatalf("listening on loopback: 127.0.0.1: %v; [::1]: %v", err4, err6)
	}
	return ln
}
86
87 func dummyRequest(method string) *Request {
88 req, err := NewRequest(method, "http://fake.tld/", nil)
89 if err != nil {
90 panic(err)
91 }
92 return req
93 }
94 func dummyRequestWithBody(method string) *Request {
95 req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
96 if err != nil {
97 panic(err)
98 }
99 return req
100 }
101
102 func dummyRequestWithBodyNoGetBody(method string) *Request {
103 req := dummyRequestWithBody(method)
104 req.GetBody = nil
105 return req
106 }
107
108
// issue22091Error is a stand-in error type for issue 22091. It is not one of
// the package's own HTTP/2 errors, but it carries an IsHTTP2NoCachedConnError
// method. TestTransportShouldRetryRequest expects shouldRetryRequest to retry
// on it.
type issue22091Error struct{}

// Error returns the fixed string "issue22091Error".
func (issue22091Error) Error() string { return "issue22091Error" }

// IsHTTP2NoCachedConnError is an empty marker method. Presumably the retry
// logic detects it via an interface check; confirm against the http2 code.
func (issue22091Error) IsHTTP2NoCachedConnError() {}
113
114 func TestTransportShouldRetryRequest(t *testing.T) {
115 tests := []struct {
116 pc *persistConn
117 req *Request
118
119 err error
120 want bool
121 }{
122 0: {
123 pc: &persistConn{reused: false},
124 req: dummyRequest("POST"),
125 err: nothingWrittenError{},
126 want: false,
127 },
128 1: {
129 pc: &persistConn{reused: true},
130 req: dummyRequest("POST"),
131 err: nothingWrittenError{},
132 want: true,
133 },
134 2: {
135 pc: &persistConn{reused: true},
136 req: dummyRequest("POST"),
137 err: http2ErrNoCachedConn,
138 want: true,
139 },
140 3: {
141 pc: nil,
142 req: nil,
143 err: issue22091Error{},
144 want: true,
145 },
146 4: {
147 pc: &persistConn{reused: true},
148 req: dummyRequest("POST"),
149 err: errMissingHost,
150 want: false,
151 },
152 5: {
153 pc: &persistConn{reused: true},
154 req: dummyRequest("POST"),
155 err: transportReadFromServerError{},
156 want: false,
157 },
158 6: {
159 pc: &persistConn{reused: true},
160 req: dummyRequest("GET"),
161 err: transportReadFromServerError{},
162 want: true,
163 },
164 7: {
165 pc: &persistConn{reused: true},
166 req: dummyRequest("GET"),
167 err: errServerClosedIdle,
168 want: true,
169 },
170 8: {
171 pc: &persistConn{reused: true},
172 req: dummyRequestWithBody("POST"),
173 err: nothingWrittenError{},
174 want: true,
175 },
176 9: {
177 pc: &persistConn{reused: true},
178 req: dummyRequestWithBodyNoGetBody("POST"),
179 err: nothingWrittenError{},
180 want: false,
181 },
182 }
183 for i, tt := range tests {
184 got := tt.pc.shouldRetryRequest(tt.req, tt.err)
185 if got != tt.want {
186 t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
187 }
188 }
189 }
190
191 type roundTripFunc func(r *Request) (*Response, error)
192
193 func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
194 return f(r)
195 }
196
197
198 func TestTransportBodyAltRewind(t *testing.T) {
199 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
200 if err != nil {
201 t.Fatal(err)
202 }
203 ln := newLocalListener(t)
204 defer ln.Close()
205
206 go func() {
207 tln := tls.NewListener(ln, &tls.Config{
208 NextProtos: []string{"foo"},
209 Certificates: []tls.Certificate{cert},
210 })
211 for i := 0; i < 2; i++ {
212 sc, err := tln.Accept()
213 if err != nil {
214 t.Error(err)
215 return
216 }
217 if err := sc.(*tls.Conn).Handshake(); err != nil {
218 t.Error(err)
219 return
220 }
221 sc.Close()
222 }
223 }()
224
225 addr := ln.Addr().String()
226 req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
227 roundTripped := false
228 tr := &Transport{
229 DisableKeepAlives: true,
230 TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
231 "foo": func(authority string, c *tls.Conn) RoundTripper {
232 return roundTripFunc(func(r *Request) (*Response, error) {
233 n, _ := io.Copy(io.Discard, r.Body)
234 if n == 0 {
235 t.Error("body length is zero")
236 }
237 if roundTripped {
238 return &Response{
239 Body: NoBody,
240 StatusCode: 200,
241 }, nil
242 }
243 roundTripped = true
244 return nil, http2noCachedConnError{}
245 })
246 },
247 },
248 DialTLS: func(_, _ string) (net.Conn, error) {
249 tc, err := tls.Dial("tcp", addr, &tls.Config{
250 InsecureSkipVerify: true,
251 NextProtos: []string{"foo"},
252 })
253 if err != nil {
254 return nil, err
255 }
256 if err := tc.Handshake(); err != nil {
257 return nil, err
258 }
259 return tc, nil
260 },
261 }
262 c := &Client{Transport: tr}
263 _, err = c.Do(req)
264 if err != nil {
265 t.Error(err)
266 }
267 }
268
View as plain text