1
2
3
4
5 package elliptic
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "encoding/hex"
11 "math/big"
12 "testing"
13 )
14
15
16
17
18
19 func genericParamsForCurve(c Curve) *CurveParams {
20 d := *(c.Params())
21 return &d
22 }
23
24 func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
25 tests := []struct {
26 name string
27 curve Curve
28 }{
29 {"P256", P256()},
30 {"P256/Params", genericParamsForCurve(P256())},
31 {"P224", P224()},
32 {"P224/Params", genericParamsForCurve(P224())},
33 {"P384", P384()},
34 {"P384/Params", genericParamsForCurve(P384())},
35 {"P521", P521()},
36 {"P521/Params", genericParamsForCurve(P521())},
37 }
38 if testing.Short() {
39 tests = tests[:1]
40 }
41 for _, test := range tests {
42 curve := test.curve
43 t.Run(test.name, func(t *testing.T) {
44 t.Parallel()
45 f(t, curve)
46 })
47 }
48 }
49
50 func TestOnCurve(t *testing.T) {
51 testAllCurves(t, func(t *testing.T, curve Curve) {
52 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
53 t.Error("basepoint is not on the curve")
54 }
55 })
56 }
57
58 func TestOffCurve(t *testing.T) {
59 testAllCurves(t, func(t *testing.T, curve Curve) {
60 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
61 if curve.IsOnCurve(x, y) {
62 t.Errorf("point off curve is claimed to be on the curve")
63 }
64 b := Marshal(curve, x, y)
65 x1, y1 := Unmarshal(curve, b)
66 if x1 != nil || y1 != nil {
67 t.Errorf("unmarshaling a point not on the curve succeeded")
68 }
69 })
70 }
71
72 func TestInfinity(t *testing.T) {
73 testAllCurves(t, testInfinity)
74 }
75
76 func testInfinity(t *testing.T, curve Curve) {
77 _, x, y, _ := GenerateKey(curve, rand.Reader)
78 x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes())
79 if x.Sign() != 0 || y.Sign() != 0 {
80 t.Errorf("x^q != ∞")
81 }
82
83 x, y = curve.ScalarBaseMult([]byte{0})
84 if x.Sign() != 0 || y.Sign() != 0 {
85 t.Errorf("b^0 != ∞")
86 x.SetInt64(0)
87 y.SetInt64(0)
88 }
89
90 x2, y2 := curve.Double(x, y)
91 if x2.Sign() != 0 || y2.Sign() != 0 {
92 t.Errorf("2∞ != ∞")
93 }
94
95 baseX := curve.Params().Gx
96 baseY := curve.Params().Gy
97
98 x3, y3 := curve.Add(baseX, baseY, x, y)
99 if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
100 t.Errorf("x+∞ != x")
101 }
102
103 x4, y4 := curve.Add(x, y, baseX, baseY)
104 if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
105 t.Errorf("∞+x != x")
106 }
107
108 if curve.IsOnCurve(x, y) {
109 t.Errorf("IsOnCurve(∞) == true")
110 }
111
112 if xx, yy := Unmarshal(curve, Marshal(curve, x, y)); xx != nil || yy != nil {
113 t.Errorf("Unmarshal(Marshal(∞)) did not return an error")
114 }
115
116
117 if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil {
118 t.Errorf("Unmarshal(∞) did not return an error")
119 }
120 }
121
122 func TestMarshal(t *testing.T) {
123 testAllCurves(t, func(t *testing.T, curve Curve) {
124 _, x, y, err := GenerateKey(curve, rand.Reader)
125 if err != nil {
126 t.Fatal(err)
127 }
128 serialized := Marshal(curve, x, y)
129 xx, yy := Unmarshal(curve, serialized)
130 if xx == nil {
131 t.Fatal("failed to unmarshal")
132 }
133 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
134 t.Fatal("unmarshal returned different values")
135 }
136 })
137 }
138
139 func TestUnmarshalToLargeCoordinates(t *testing.T) {
140
141 testAllCurves(t, testUnmarshalToLargeCoordinates)
142 }
143
144 func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
145 p := curve.Params().P
146 byteLen := (p.BitLen() + 7) / 8
147
148
149
150
151 x := new(big.Int).Add(p, big.NewInt(5))
152 y := curve.Params().polynomial(x)
153 y.ModSqrt(y, p)
154
155 invalid := make([]byte, byteLen*2+1)
156 invalid[0] = 4
157 x.FillBytes(invalid[1 : 1+byteLen])
158 y.FillBytes(invalid[1+byteLen:])
159
160 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
161 t.Errorf("Unmarshal accepts invalid X coordinate")
162 }
163
164 if curve == p256 {
165
166
167 x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
168 y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
169 y.Add(y, p)
170
171 if p.Cmp(y) > 0 || y.BitLen() != 256 {
172 t.Fatal("y not within expected range")
173 }
174
175
176 x.FillBytes(invalid[1 : 1+byteLen])
177 y.FillBytes(invalid[1+byteLen:])
178
179 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
180 t.Errorf("Unmarshal accepts invalid Y coordinate")
181 }
182 }
183 }
184
185
186
187
188 func TestInvalidCoordinates(t *testing.T) {
189 testAllCurves(t, testInvalidCoordinates)
190 }
191
192 func testInvalidCoordinates(t *testing.T, curve Curve) {
193 checkIsOnCurveFalse := func(name string, x, y *big.Int) {
194 if curve.IsOnCurve(x, y) {
195 t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
196 }
197 }
198
199 p := curve.Params().P
200 _, x, y, _ := GenerateKey(curve, rand.Reader)
201 xx, yy := new(big.Int), new(big.Int)
202
203
204 xx.Neg(x)
205 checkIsOnCurveFalse("-x, y", xx, y)
206 yy.Neg(y)
207 checkIsOnCurveFalse("x, -y", x, yy)
208
209
210 xx.Sub(x, p)
211 checkIsOnCurveFalse("x-P, y", xx, y)
212 yy.Sub(y, p)
213 checkIsOnCurveFalse("x, y-P", x, yy)
214
215
216 xx.Add(x, p)
217 checkIsOnCurveFalse("x+P, y", xx, y)
218 yy.Add(y, p)
219 checkIsOnCurveFalse("x, y+P", x, yy)
220
221
222 xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
223 checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
224 yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
225 checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
226
227
228
229
230
231
232 if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
233 if !curve.IsOnCurve(big.NewInt(0), yy) {
234 t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
235 }
236 checkIsOnCurveFalse("P, y", p, yy)
237 }
238 }
239
240 func TestMarshalCompressed(t *testing.T) {
241 t.Run("P-256/03", func(t *testing.T) {
242 data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
243 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
244 y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
245 testMarshalCompressed(t, P256(), x, y, data)
246 })
247 t.Run("P-256/02", func(t *testing.T) {
248 data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
249 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
250 y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
251 testMarshalCompressed(t, P256(), x, y, data)
252 })
253
254 t.Run("Invalid", func(t *testing.T) {
255 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
256 X, Y := UnmarshalCompressed(P256(), data)
257 if X != nil || Y != nil {
258 t.Error("expected an error for invalid encoding")
259 }
260 })
261
262 if testing.Short() {
263 t.Skip("skipping other curves on short test")
264 }
265
266 testAllCurves(t, func(t *testing.T, curve Curve) {
267 _, x, y, err := GenerateKey(curve, rand.Reader)
268 if err != nil {
269 t.Fatal(err)
270 }
271 testMarshalCompressed(t, curve, x, y, nil)
272 })
273
274 }
275
276 func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
277 if !curve.IsOnCurve(x, y) {
278 t.Fatal("invalid test point")
279 }
280 got := MarshalCompressed(curve, x, y)
281 if want != nil && !bytes.Equal(got, want) {
282 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
283 }
284
285 X, Y := UnmarshalCompressed(curve, got)
286 if X == nil || Y == nil {
287 t.Fatalf("UnmarshalCompressed failed unexpectedly")
288 }
289
290 if !curve.IsOnCurve(X, Y) {
291 t.Error("UnmarshalCompressed returned a point not on the curve")
292 }
293 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
294 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
295 }
296 }
297
298 func TestLargeIsOnCurve(t *testing.T) {
299 testAllCurves(t, func(t *testing.T, curve Curve) {
300 large := big.NewInt(1)
301 large.Lsh(large, 1000)
302 if curve.IsOnCurve(large, large) {
303 t.Errorf("(2^1000, 2^1000) is reported on the curve")
304 }
305 })
306 }
307
308 func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) {
309 tests := []struct {
310 name string
311 curve Curve
312 }{
313 {"P256", P256()},
314 {"P224", P224()},
315 {"P384", P384()},
316 {"P521", P521()},
317 }
318 for _, test := range tests {
319 curve := test.curve
320 t.Run(test.name, func(t *testing.B) {
321 f(t, curve)
322 })
323 }
324 }
325
326 func BenchmarkScalarBaseMult(b *testing.B) {
327 benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
328 priv, _, _, _ := GenerateKey(curve, rand.Reader)
329 b.ReportAllocs()
330 b.ResetTimer()
331 for i := 0; i < b.N; i++ {
332 x, _ := curve.ScalarBaseMult(priv)
333
334 priv[0] ^= byte(x.Bits()[0])
335 }
336 })
337 }
338
339 func BenchmarkScalarMult(b *testing.B) {
340 benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
341 _, x, y, _ := GenerateKey(curve, rand.Reader)
342 priv, _, _, _ := GenerateKey(curve, rand.Reader)
343 b.ReportAllocs()
344 b.ResetTimer()
345 for i := 0; i < b.N; i++ {
346 x, y = curve.ScalarMult(x, y, priv)
347 }
348 })
349 }
350
351 func BenchmarkMarshalUnmarshal(b *testing.B) {
352 benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
353 _, x, y, _ := GenerateKey(curve, rand.Reader)
354 b.Run("Uncompressed", func(b *testing.B) {
355 b.ReportAllocs()
356 for i := 0; i < b.N; i++ {
357 buf := Marshal(curve, x, y)
358 xx, yy := Unmarshal(curve, buf)
359 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
360 b.Error("Unmarshal output different from Marshal input")
361 }
362 }
363 })
364 b.Run("Compressed", func(b *testing.B) {
365 b.ReportAllocs()
366 for i := 0; i < b.N; i++ {
367 buf := Marshal(curve, x, y)
368 xx, yy := Unmarshal(curve, buf)
369 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
370 b.Error("Unmarshal output different from Marshal input")
371 }
372 }
373 })
374 })
375 }
376
View as plain text