1
2
3
4
5
6
7 package tls
8
9
10
11
12
13
14 import (
15 "bytes"
16 "context"
17 "crypto"
18 "crypto/ecdsa"
19 "crypto/ed25519"
20 "crypto/rsa"
21 "crypto/x509"
22 "encoding/pem"
23 "errors"
24 "fmt"
25 "net"
26 "os"
27 "strings"
28 )
29
30
31
32
33
34 func Server(conn net.Conn, config *Config) *Conn {
35 c := &Conn{
36 conn: conn,
37 config: config,
38 }
39 c.handshakeFn = c.serverHandshake
40 return c
41 }
42
43
44
45
46
47 func Client(conn net.Conn, config *Config) *Conn {
48 c := &Conn{
49 conn: conn,
50 config: config,
51 isClient: true,
52 }
53 c.handshakeFn = c.clientHandshake
54 return c
55 }
56
57
58 type listener struct {
59 net.Listener
60 config *Config
61 }
62
63
64
65 func (l *listener) Accept() (net.Conn, error) {
66 c, err := l.Listener.Accept()
67 if err != nil {
68 return nil, err
69 }
70 return Server(c, l.config), nil
71 }
72
73
74
75
76
77 func NewListener(inner net.Listener, config *Config) net.Listener {
78 l := new(listener)
79 l.Listener = inner
80 l.config = config
81 return l
82 }
83
84
85
86
87
88 func Listen(network, laddr string, config *Config) (net.Listener, error) {
89 if config == nil || len(config.Certificates) == 0 &&
90 config.GetCertificate == nil && config.GetConfigForClient == nil {
91 return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
92 }
93 l, err := net.Listen(network, laddr)
94 if err != nil {
95 return nil, err
96 }
97 return NewListener(l, config), nil
98 }
99
// timeoutError is a stateless error value that reports itself as both a
// timeout and a temporary condition.
type timeoutError struct{}

// Error returns the fixed message for this error.
func (timeoutError) Error() string {
	return "tls: DialWithDialer timed out"
}

// Timeout always reports true.
func (timeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (timeoutError) Temporary() bool {
	return true
}
105
106
107
108
109
110
111
112
113
114
115
116 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
117 return dial(context.Background(), dialer, network, addr, config)
118 }
119
120 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
121 if netDialer.Timeout != 0 {
122 var cancel context.CancelFunc
123 ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
124 defer cancel()
125 }
126
127 if !netDialer.Deadline.IsZero() {
128 var cancel context.CancelFunc
129 ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
130 defer cancel()
131 }
132
133 rawConn, err := netDialer.DialContext(ctx, network, addr)
134 if err != nil {
135 return nil, err
136 }
137
138 colonPos := strings.LastIndex(addr, ":")
139 if colonPos == -1 {
140 colonPos = len(addr)
141 }
142 hostname := addr[:colonPos]
143
144 if config == nil {
145 config = defaultConfig()
146 }
147
148
149 if config.ServerName == "" {
150
151 c := config.Clone()
152 c.ServerName = hostname
153 config = c
154 }
155
156 conn := Client(rawConn, config)
157 if err := conn.HandshakeContext(ctx); err != nil {
158 rawConn.Close()
159 return nil, err
160 }
161 return conn, nil
162 }
163
164
165
166
167
168
169
170 func Dial(network, addr string, config *Config) (*Conn, error) {
171 return DialWithDialer(new(net.Dialer), network, addr, config)
172 }
173
174
175
// Dialer dials TLS connections from a TLS configuration and an optional
// dialer for the underlying network connection.
//
// A zero Dialer can be used: a nil NetDialer is replaced by a zero-value
// net.Dialer, and a nil Config is replaced by the default configuration at
// dial time.
type Dialer struct {
	// NetDialer is the optional dialer for the underlying network connection.
	// When nil, netDialer returns a new zero-value net.Dialer on each call.
	// Its Timeout and Deadline, if set, bound both the dial and the TLS
	// handshake.
	NetDialer *net.Dialer

	// Config is the TLS configuration for new connections. When nil, the
	// default configuration is used. If its ServerName is empty, a clone with
	// ServerName derived from the dialed address is used instead, and Config
	// itself is left unmodified.
	Config *Config
}
188
189
190
191
192
193
194
195
196 func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
197 return d.DialContext(context.Background(), network, addr)
198 }
199
200 func (d *Dialer) netDialer() *net.Dialer {
201 if d.NetDialer != nil {
202 return d.NetDialer
203 }
204 return new(net.Dialer)
205 }
206
207
208
209
210
211
212
213
214
215
216 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
217 c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
218 if err != nil {
219
220 return nil, err
221 }
222 return c, nil
223 }
224
225
226
227
228
229
230 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
231 certPEMBlock, err := os.ReadFile(certFile)
232 if err != nil {
233 return Certificate{}, err
234 }
235 keyPEMBlock, err := os.ReadFile(keyFile)
236 if err != nil {
237 return Certificate{}, err
238 }
239 return X509KeyPair(certPEMBlock, keyPEMBlock)
240 }
241
242
243
244
245 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
246 fail := func(err error) (Certificate, error) { return Certificate{}, err }
247
248 var cert Certificate
249 var skippedBlockTypes []string
250 for {
251 var certDERBlock *pem.Block
252 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
253 if certDERBlock == nil {
254 break
255 }
256 if certDERBlock.Type == "CERTIFICATE" {
257 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
258 } else {
259 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
260 }
261 }
262
263 if len(cert.Certificate) == 0 {
264 if len(skippedBlockTypes) == 0 {
265 return fail(errors.New("tls: failed to find any PEM data in certificate input"))
266 }
267 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
268 return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
269 }
270 return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
271 }
272
273 skippedBlockTypes = skippedBlockTypes[:0]
274 var keyDERBlock *pem.Block
275 for {
276 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
277 if keyDERBlock == nil {
278 if len(skippedBlockTypes) == 0 {
279 return fail(errors.New("tls: failed to find any PEM data in key input"))
280 }
281 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
282 return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
283 }
284 return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
285 }
286 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
287 break
288 }
289 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
290 }
291
292
293
294 x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
295 if err != nil {
296 return fail(err)
297 }
298
299 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
300 if err != nil {
301 return fail(err)
302 }
303
304 switch pub := x509Cert.PublicKey.(type) {
305 case *rsa.PublicKey:
306 priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
307 if !ok {
308 return fail(errors.New("tls: private key type does not match public key type"))
309 }
310 if pub.N.Cmp(priv.N) != 0 {
311 return fail(errors.New("tls: private key does not match public key"))
312 }
313 case *ecdsa.PublicKey:
314 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
315 if !ok {
316 return fail(errors.New("tls: private key type does not match public key type"))
317 }
318 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
319 return fail(errors.New("tls: private key does not match public key"))
320 }
321 case ed25519.PublicKey:
322 priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
323 if !ok {
324 return fail(errors.New("tls: private key type does not match public key type"))
325 }
326 if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
327 return fail(errors.New("tls: private key does not match public key"))
328 }
329 default:
330 return fail(errors.New("tls: unknown public key algorithm"))
331 }
332
333 return cert, nil
334 }
335
336
337
338
// parsePrivateKey decodes a DER-encoded private key. It tries three encodings
// in turn: PKCS #1 (RSA), PKCS #8 and SEC 1 (EC).
//
// A PKCS #8 payload is accepted only if it holds an *rsa.PrivateKey, an
// *ecdsa.PrivateKey or an ed25519.PrivateKey. Any other key type in a PKCS #8
// wrapper is rejected immediately, and SEC 1 is not tried. If no encoding
// parses, a generic error is returned.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
	if rsaKey, err := x509.ParsePKCS1PrivateKey(der); err == nil {
		return rsaKey, nil
	}

	wrapped, pkcs8Err := x509.ParsePKCS8PrivateKey(der)
	if pkcs8Err == nil {
		switch wrapped.(type) {
		case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
			return wrapped, nil
		}
		return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
	}

	if ecKey, err := x509.ParseECPrivateKey(der); err == nil {
		return ecKey, nil
	}
	return nil, errors.New("tls: failed to parse private key")
}
357
View as plain text