1
2
3
4
5
6
7 package elliptic
8
9 import (
10 "io"
11 "math/big"
12 "sync"
13 )
14
15
16
17
18
19
20
21
22
23 type Curve interface {
24
25 Params() *CurveParams
26
27 IsOnCurve(x, y *big.Int) bool
28
29 Add(x1, y1, x2, y2 *big.Int) (x, y *big.Int)
30
31 Double(x1, y1 *big.Int) (x, y *big.Int)
32
33 ScalarMult(x1, y1 *big.Int, k []byte) (x, y *big.Int)
34
35
36 ScalarBaseMult(k []byte) (x, y *big.Int)
37 }
38
39 func matchesSpecificCurve(params *CurveParams, available ...Curve) (Curve, bool) {
40 for _, c := range available {
41 if params == c.Params() {
42 return c, true
43 }
44 }
45 return nil, false
46 }
47
48
49
// CurveParams holds the parameters of a curve y² = x³ - 3x + B over the prime
// field of order P, and provides a generic, non-constant-time Curve
// implementation on top of them.
type CurveParams struct {
	// P is the prime modulus of the underlying field.
	P *big.Int
	// N bounds the private scalars drawn by GenerateKey (conventionally the
	// order of the base point).
	N *big.Int
	// B is the constant term of the curve equation.
	B *big.Int
	// Gx is the x coordinate of the base point used by ScalarBaseMult.
	Gx *big.Int
	// Gy is the y coordinate of the base point used by ScalarBaseMult.
	Gy *big.Int
	// BitSize sets the encoded coordinate width, (BitSize+7)/8 bytes, used by
	// Marshal and Unmarshal.
	BitSize int
	// Name is a human-readable name for the curve.
	Name string
}
58
59 func (curve *CurveParams) Params() *CurveParams {
60 return curve
61 }
62
63
64
65
66
67
68
69
70
71 func (curve *CurveParams) polynomial(x *big.Int) *big.Int {
72 x3 := new(big.Int).Mul(x, x)
73 x3.Mul(x3, x)
74
75 threeX := new(big.Int).Lsh(x, 1)
76 threeX.Add(threeX, x)
77
78 x3.Sub(x3, threeX)
79 x3.Add(x3, curve.B)
80 x3.Mod(x3, curve.P)
81
82 return x3
83 }
84
85 func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
86
87
88 if specific, ok := matchesSpecificCurve(curve, p224, p384, p521); ok {
89 return specific.IsOnCurve(x, y)
90 }
91
92 if x.Sign() < 0 || x.Cmp(curve.P) >= 0 ||
93 y.Sign() < 0 || y.Cmp(curve.P) >= 0 {
94 return false
95 }
96
97
98 y2 := new(big.Int).Mul(y, y)
99 y2.Mod(y2, curve.P)
100
101 return curve.polynomial(x).Cmp(y2) == 0
102 }
103
104
105
106
// zForAffine returns the Jacobian z coordinate for an affine point. It is 0
// for (0, 0), which this package uses for the point at infinity, and 1 for
// every other point. The returned value is freshly allocated.
func zForAffine(x, y *big.Int) *big.Int {
	if x.Sign() == 0 && y.Sign() == 0 {
		return new(big.Int)
	}
	return big.NewInt(1)
}
114
115
116
117 func (curve *CurveParams) affineFromJacobian(x, y, z *big.Int) (xOut, yOut *big.Int) {
118 if z.Sign() == 0 {
119 return new(big.Int), new(big.Int)
120 }
121
122 zinv := new(big.Int).ModInverse(z, curve.P)
123 zinvsq := new(big.Int).Mul(zinv, zinv)
124
125 xOut = new(big.Int).Mul(x, zinvsq)
126 xOut.Mod(xOut, curve.P)
127 zinvsq.Mul(zinvsq, zinv)
128 yOut = new(big.Int).Mul(y, zinvsq)
129 yOut.Mod(yOut, curve.P)
130 return
131 }
132
133 func (curve *CurveParams) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
134
135
136 if specific, ok := matchesSpecificCurve(curve, p224, p384, p521); ok {
137 return specific.Add(x1, y1, x2, y2)
138 }
139
140 z1 := zForAffine(x1, y1)
141 z2 := zForAffine(x2, y2)
142 return curve.affineFromJacobian(curve.addJacobian(x1, y1, z1, x2, y2, z2))
143 }
144
145
146
// addJacobian returns the sum of the Jacobian points (x1, y1, z1) and
// (x2, y2, z2), also in Jacobian coordinates.
//
// The arithmetic follows the add-2007-bl formulas (short Weierstrass curve,
// Jacobian coordinates). A point with z == 0 is treated as the point at
// infinity. When both inputs are the same point, the addition formula
// degenerates, so the work is handed to doubleJacobian. Every result is
// freshly allocated and the arguments are not modified. The code branches on
// the input values, so it is not constant-time.
func (curve *CurveParams) addJacobian(x1, y1, z1, x2, y2, z2 *big.Int) (*big.Int, *big.Int, *big.Int) {
	// See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl
	x3, y3, z3 := new(big.Int), new(big.Int), new(big.Int)
	// ∞ + Q = Q.
	if z1.Sign() == 0 {
		x3.Set(x2)
		y3.Set(y2)
		z3.Set(z2)
		return x3, y3, z3
	}
	// P + ∞ = P.
	if z2.Sign() == 0 {
		x3.Set(x1)
		y3.Set(y1)
		z3.Set(z1)
		return x3, y3, z3
	}

	z1z1 := new(big.Int).Mul(z1, z1)
	z1z1.Mod(z1z1, curve.P)
	z2z2 := new(big.Int).Mul(z2, z2)
	z2z2.Mod(z2z2, curve.P)

	// u1, u2: the x coordinates scaled to a common denominator.
	u1 := new(big.Int).Mul(x1, z2z2)
	u1.Mod(u1, curve.P)
	u2 := new(big.Int).Mul(x2, z1z1)
	u2.Mod(u2, curve.P)
	// h = u2 - u1, normalized into [0, P). h == 0 means the affine x
	// coordinates are equal.
	h := new(big.Int).Sub(u2, u1)
	xEqual := h.Sign() == 0
	if h.Sign() == -1 {
		h.Add(h, curve.P)
	}
	// i = (2h)², j = h·i (left unreduced).
	i := new(big.Int).Lsh(h, 1)
	i.Mul(i, i)
	j := new(big.Int).Mul(h, i)

	// s1, s2: the y coordinates scaled to a common denominator.
	s1 := new(big.Int).Mul(y1, z2)
	s1.Mul(s1, z2z2)
	s1.Mod(s1, curve.P)
	s2 := new(big.Int).Mul(y2, z1)
	s2.Mul(s2, z1z1)
	s2.Mod(s2, curve.P)
	r := new(big.Int).Sub(s2, s1)
	if r.Sign() == -1 {
		r.Add(r, curve.P)
	}
	yEqual := r.Sign() == 0
	// Identical points: the chord formula divides by zero, so double instead.
	// If only x matches (P + (-P)), h is 0 and the z3 computed below becomes
	// 0, which is the point at infinity.
	if xEqual && yEqual {
		return curve.doubleJacobian(x1, y1, z1)
	}
	// r = 2·(s2 - s1), v = u1·i.
	r.Lsh(r, 1)
	v := new(big.Int).Mul(u1, i)

	// x3 = r² - j - 2v
	x3.Set(r)
	x3.Mul(x3, x3)
	x3.Sub(x3, j)
	x3.Sub(x3, v)
	x3.Sub(x3, v)
	x3.Mod(x3, curve.P)

	// y3 = r·(v - x3) - 2·s1·j (v and s1 are consumed in place).
	y3.Set(r)
	v.Sub(v, x3)
	y3.Mul(y3, v)
	s1.Mul(s1, j)
	s1.Lsh(s1, 1)
	y3.Sub(y3, s1)
	y3.Mod(y3, curve.P)

	// z3 = ((z1 + z2)² - z1z1 - z2z2)·h
	z3.Add(z1, z2)
	z3.Mul(z3, z3)
	z3.Sub(z3, z1z1)
	z3.Sub(z3, z2z2)
	z3.Mul(z3, h)
	z3.Mod(z3, curve.P)

	return x3, y3, z3
}
222
223 func (curve *CurveParams) Double(x1, y1 *big.Int) (*big.Int, *big.Int) {
224
225
226 if specific, ok := matchesSpecificCurve(curve, p224, p384, p521); ok {
227 return specific.Double(x1, y1)
228 }
229
230 z1 := zForAffine(x1, y1)
231 return curve.affineFromJacobian(curve.doubleJacobian(x1, y1, z1))
232 }
233
234
235
// doubleJacobian returns 2·(x, y, z) for a point in Jacobian coordinates.
//
// The arithmetic follows the dbl-2001-b formulas, which hard-code the curve
// coefficient a = -3: alpha = 3(x - z²)(x + z²) = 3x² - 3z⁴. The arguments are
// not modified. Local temporaries are reused (alpha2 becomes beta, alpha
// becomes y3), so the statement order below matters.
func (curve *CurveParams) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int, *big.Int) {
	// See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2001-b
	// delta = z², gamma = y² (both reduced).
	delta := new(big.Int).Mul(z, z)
	delta.Mod(delta, curve.P)
	gamma := new(big.Int).Mul(y, y)
	gamma.Mod(gamma, curve.P)
	// alpha = 3·(x - delta)·(x + delta)
	alpha := new(big.Int).Sub(x, delta)
	if alpha.Sign() == -1 {
		alpha.Add(alpha, curve.P)
	}
	alpha2 := new(big.Int).Add(x, delta)
	alpha.Mul(alpha, alpha2)
	alpha2.Set(alpha)
	alpha.Lsh(alpha, 1)
	alpha.Add(alpha, alpha2)

	// beta = x·gamma, reusing alpha2's storage.
	beta := alpha2.Mul(x, gamma)

	// x3 = alpha² - 8·beta
	x3 := new(big.Int).Mul(alpha, alpha)
	beta8 := new(big.Int).Lsh(beta, 3)
	beta8.Mod(beta8, curve.P)
	x3.Sub(x3, beta8)
	if x3.Sign() == -1 {
		x3.Add(x3, curve.P)
	}
	x3.Mod(x3, curve.P)

	// z3 = (y + z)² - gamma - delta
	z3 := new(big.Int).Add(y, z)
	z3.Mul(z3, z3)
	z3.Sub(z3, gamma)
	if z3.Sign() == -1 {
		z3.Add(z3, curve.P)
	}
	z3.Sub(z3, delta)
	if z3.Sign() == -1 {
		z3.Add(z3, curve.P)
	}
	z3.Mod(z3, curve.P)

	// y3 = alpha·(4·beta - x3) - 8·gamma², reusing alpha's storage for y3.
	beta.Lsh(beta, 2)
	beta.Sub(beta, x3)
	if beta.Sign() == -1 {
		beta.Add(beta, curve.P)
	}
	y3 := alpha.Mul(alpha, beta)

	gamma.Mul(gamma, gamma)
	gamma.Lsh(gamma, 3)
	gamma.Mod(gamma, curve.P)

	y3.Sub(y3, gamma)
	if y3.Sign() == -1 {
		y3.Add(y3, curve.P)
	}
	y3.Mod(y3, curve.P)

	return x3, y3, z3
}
294
295 func (curve *CurveParams) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) {
296
297
298 if specific, ok := matchesSpecificCurve(curve, p224, p256, p384, p521); ok {
299 return specific.ScalarMult(Bx, By, k)
300 }
301
302 Bz := new(big.Int).SetInt64(1)
303 x, y, z := new(big.Int), new(big.Int), new(big.Int)
304
305 for _, byte := range k {
306 for bitNum := 0; bitNum < 8; bitNum++ {
307 x, y, z = curve.doubleJacobian(x, y, z)
308 if byte&0x80 == 0x80 {
309 x, y, z = curve.addJacobian(Bx, By, Bz, x, y, z)
310 }
311 byte <<= 1
312 }
313 }
314
315 return curve.affineFromJacobian(x, y, z)
316 }
317
318 func (curve *CurveParams) ScalarBaseMult(k []byte) (*big.Int, *big.Int) {
319
320
321 if specific, ok := matchesSpecificCurve(curve, p224, p256, p384, p521); ok {
322 return specific.ScalarBaseMult(k)
323 }
324
325 return curve.ScalarMult(curve.Gx, curve.Gy, k)
326 }
327
328 var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
329
330
331
332 func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
333 N := curve.Params().N
334 bitSize := N.BitLen()
335 byteLen := (bitSize + 7) / 8
336 priv = make([]byte, byteLen)
337
338 for x == nil {
339 _, err = io.ReadFull(rand, priv)
340 if err != nil {
341 return
342 }
343
344
345 priv[0] &= mask[bitSize%8]
346
347
348 priv[1] ^= 0x42
349
350
351 if new(big.Int).SetBytes(priv).Cmp(N) >= 0 {
352 continue
353 }
354
355 x, y = curve.ScalarBaseMult(priv)
356 }
357 return
358 }
359
360
361
362
363 func Marshal(curve Curve, x, y *big.Int) []byte {
364 byteLen := (curve.Params().BitSize + 7) / 8
365
366 ret := make([]byte, 1+2*byteLen)
367 ret[0] = 4
368
369 x.FillBytes(ret[1 : 1+byteLen])
370 y.FillBytes(ret[1+byteLen : 1+2*byteLen])
371
372 return ret
373 }
374
375
376
377
378 func MarshalCompressed(curve Curve, x, y *big.Int) []byte {
379 byteLen := (curve.Params().BitSize + 7) / 8
380 compressed := make([]byte, 1+byteLen)
381 compressed[0] = byte(y.Bit(0)) | 2
382 x.FillBytes(compressed[1:])
383 return compressed
384 }
385
386
387
388
389 func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
390 byteLen := (curve.Params().BitSize + 7) / 8
391 if len(data) != 1+2*byteLen {
392 return nil, nil
393 }
394 if data[0] != 4 {
395 return nil, nil
396 }
397 p := curve.Params().P
398 x = new(big.Int).SetBytes(data[1 : 1+byteLen])
399 y = new(big.Int).SetBytes(data[1+byteLen:])
400 if x.Cmp(p) >= 0 || y.Cmp(p) >= 0 {
401 return nil, nil
402 }
403 if !curve.IsOnCurve(x, y) {
404 return nil, nil
405 }
406 return
407 }
408
409
410
411
412 func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) {
413 byteLen := (curve.Params().BitSize + 7) / 8
414 if len(data) != 1+byteLen {
415 return nil, nil
416 }
417 if data[0] != 2 && data[0] != 3 {
418 return nil, nil
419 }
420 p := curve.Params().P
421 x = new(big.Int).SetBytes(data[1:])
422 if x.Cmp(p) >= 0 {
423 return nil, nil
424 }
425
426 y = curve.Params().polynomial(x)
427 y = y.ModSqrt(y, p)
428 if y == nil {
429 return nil, nil
430 }
431 if byte(y.Bit(0)) != data[0]&1 {
432 y.Neg(y).Mod(y, p)
433 }
434 if !curve.IsOnCurve(x, y) {
435 return nil, nil
436 }
437 return
438 }
439
440 var initonce sync.Once
441
442 func initAll() {
443 initP224()
444 initP256()
445 initP384()
446 initP521()
447 }
448
449
450
451
452
453
454
455
456 func P224() Curve {
457 initonce.Do(initAll)
458 return p224
459 }
460
461
462
463
464
465
466
467
468
469 func P256() Curve {
470 initonce.Do(initAll)
471 return p256
472 }
473
474
475
476
477
478
479
480
481 func P384() Curve {
482 initonce.Do(initAll)
483 return p384
484 }
485
486
487
488
489
490
491
492
493 func P521() Curve {
494 initonce.Do(initAll)
495 return p521
496 }
497
View as plain text