1
2
3
4
5 package field
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "encoding/hex"
11 "io"
12 "math/big"
13 "math/bits"
14 mathrand "math/rand"
15 "reflect"
16 "testing"
17 "testing/quick"
18 )
19
20 func (v Element) String() string {
21 return hex.EncodeToString(v.Bytes())
22 }
23
24
25
// quickCheckConfig1024 multiplies testing/quick's default iteration count by
// 1024, for properties that deserve heavier random coverage.
var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1024}
27
28 func generateFieldElement(rand *mathrand.Rand) Element {
29 const maskLow52Bits = (1 << 52) - 1
30 return Element{
31 rand.Uint64() & maskLow52Bits,
32 rand.Uint64() & maskLow52Bits,
33 rand.Uint64() & maskLow52Bits,
34 rand.Uint64() & maskLow52Bits,
35 rand.Uint64() & maskLow52Bits,
36 }
37 }
38
39
40
// Tables of edge-case limb values for generateWeirdFieldElement. Entries are
// picked uniformly by index, so the repeated values below are deliberately
// drawn more often. The order of entries is therefore part of the behavior
// for a given RNG seed.
var (
	// weirdLimbs51 holds values that fit in 51 bits: zero (four times), small
	// values around 19, alternating-bit patterns, and values just below 2^51
	// (2^51 - 1 four times).
	weirdLimbs51 = []uint64{
		0, 0, 0, 0,
		1,
		19 - 1,
		19,
		0x2aaaaaaaaaaaa,
		0x5555555555555,
		(1 << 51) - 20,
		(1 << 51) - 19,
		(1 << 51) - 1, (1 << 51) - 1,
		(1 << 51) - 1, (1 << 51) - 1,
	}
	// weirdLimbs52 repeats the 51-bit cases (with zero and 2^51 - 1 listed six
	// times each) and adds values that need a 52nd bit: 2^51, 2^51 + 1,
	// 2^52 - 19 and 2^52 - 1. It is only used for the first limb.
	weirdLimbs52 = []uint64{
		0, 0, 0, 0, 0, 0,
		1,
		19 - 1,
		19,
		0x2aaaaaaaaaaaa,
		0x5555555555555,
		(1 << 51) - 20,
		(1 << 51) - 19,
		(1 << 51) - 1, (1 << 51) - 1,
		(1 << 51) - 1, (1 << 51) - 1,
		(1 << 51) - 1, (1 << 51) - 1,
		1 << 51,
		(1 << 51) + 1,
		(1 << 52) - 19,
		(1 << 52) - 1,
	}
)
72
73 func generateWeirdFieldElement(rand *mathrand.Rand) Element {
74 return Element{
75 weirdLimbs52[rand.Intn(len(weirdLimbs52))],
76 weirdLimbs51[rand.Intn(len(weirdLimbs51))],
77 weirdLimbs51[rand.Intn(len(weirdLimbs51))],
78 weirdLimbs51[rand.Intn(len(weirdLimbs51))],
79 weirdLimbs51[rand.Intn(len(weirdLimbs51))],
80 }
81 }
82
83 func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
84 if rand.Intn(2) == 0 {
85 return reflect.ValueOf(generateWeirdFieldElement(rand))
86 }
87 return reflect.ValueOf(generateFieldElement(rand))
88 }
89
90
91
92 func isInBounds(x *Element) bool {
93 return bits.Len64(x.l0) <= 52 &&
94 bits.Len64(x.l1) <= 52 &&
95 bits.Len64(x.l2) <= 52 &&
96 bits.Len64(x.l3) <= 52 &&
97 bits.Len64(x.l4) <= 52
98 }
99
100 func TestMultiplyDistributesOverAdd(t *testing.T) {
101 multiplyDistributesOverAdd := func(x, y, z Element) bool {
102
103 t1 := new(Element)
104 t1.Add(&x, &y)
105 t1.Multiply(t1, &z)
106
107
108 t2 := new(Element)
109 t3 := new(Element)
110 t2.Multiply(&x, &z)
111 t3.Multiply(&y, &z)
112 t2.Add(t2, t3)
113
114 return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
115 }
116
117 if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil {
118 t.Error(err)
119 }
120 }
121
122 func TestMul64to128(t *testing.T) {
123 a := uint64(5)
124 b := uint64(5)
125 r := mul64(a, b)
126 if r.lo != 0x19 || r.hi != 0 {
127 t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
128 }
129
130 a = uint64(18014398509481983)
131 b = uint64(18014398509481983)
132 r = mul64(a, b)
133 if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff {
134 t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
135 }
136
137 a = uint64(1125899906842661)
138 b = uint64(2097155)
139 r = mul64(a, b)
140 r = addMul64(r, a, b)
141 r = addMul64(r, a, b)
142 r = addMul64(r, a, b)
143 r = addMul64(r, a, b)
144 if r.lo != 16888498990613035 || r.hi != 640 {
145 t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi)
146 }
147 }
148
149 func TestSetBytesRoundTrip(t *testing.T) {
150 f1 := func(in [32]byte, fe Element) bool {
151 fe.SetBytes(in[:])
152
153
154
155 in[len(in)-1] &= (1 << 7) - 1
156
157 return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
158 }
159 if err := quick.Check(f1, nil); err != nil {
160 t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
161 }
162
163 f2 := func(fe, r Element) bool {
164 r.SetBytes(fe.Bytes())
165
166
167
168
169 fe.reduce()
170 r.reduce()
171 return fe == r
172 }
173 if err := quick.Check(f2, nil); err != nil {
174 t.Errorf("failed FE->bytes->FE round-trip: %v", err)
175 }
176
177
178 type feRTTest struct {
179 fe Element
180 b []byte
181 }
182 var tests = []feRTTest{
183 {
184 fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
185 b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
186 },
187 {
188 fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
189 b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
190 },
191 }
192
193 for _, tt := range tests {
194 b := tt.fe.Bytes()
195 if !bytes.Equal(b, tt.b) || new(Element).SetBytes(tt.b).Equal(&tt.fe) != 1 {
196 t.Errorf("Failed fixed roundtrip: %v", tt)
197 }
198 }
199 }
200
// swapEndianness reverses buf in place and returns it, converting between
// little-endian and big-endian byte order.
func swapEndianness(buf []byte) []byte {
	for lo, hi := 0, len(buf)-1; lo < hi; lo, hi = lo+1, hi-1 {
		buf[lo], buf[hi] = buf[hi], buf[lo]
	}
	return buf
}
207
208 func TestBytesBigEquivalence(t *testing.T) {
209 f1 := func(in [32]byte, fe, fe1 Element) bool {
210 fe.SetBytes(in[:])
211
212 in[len(in)-1] &= (1 << 7) - 1
213 b := new(big.Int).SetBytes(swapEndianness(in[:]))
214 fe1.fromBig(b)
215
216 if fe != fe1 {
217 return false
218 }
219
220 buf := make([]byte, 32)
221 copy(buf, swapEndianness(fe1.toBig().Bytes()))
222
223 return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1)
224 }
225 if err := quick.Check(f1, nil); err != nil {
226 t.Error(err)
227 }
228 }
229
230
231 func (v *Element) fromBig(n *big.Int) *Element {
232 if n.BitLen() > 32*8 {
233 panic("edwards25519: invalid field element input size")
234 }
235
236 buf := make([]byte, 0, 32)
237 for _, word := range n.Bits() {
238 for i := 0; i < bits.UintSize; i += 8 {
239 if len(buf) >= cap(buf) {
240 break
241 }
242 buf = append(buf, byte(word))
243 word >>= 8
244 }
245 }
246
247 return v.SetBytes(buf[:32])
248 }
249
250 func (v *Element) fromDecimal(s string) *Element {
251 n, ok := new(big.Int).SetString(s, 10)
252 if !ok {
253 panic("not a valid decimal: " + s)
254 }
255 return v.fromBig(n)
256 }
257
258
259 func (v *Element) toBig() *big.Int {
260 buf := v.Bytes()
261
262 words := make([]big.Word, 32*8/bits.UintSize)
263 for n := range words {
264 for i := 0; i < bits.UintSize; i += 8 {
265 if len(buf) == 0 {
266 break
267 }
268 words[n] |= big.Word(buf[0]) << big.Word(i)
269 buf = buf[1:]
270 }
271 }
272
273 return new(big.Int).SetBits(words)
274 }
275
276 func TestDecimalConstants(t *testing.T) {
277 sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
278 if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
279 t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
280 }
281
282
283
284
285
286 }
287
288 func TestSetBytesRoundTripEdgeCases(t *testing.T) {
289
290
291
292 }
293
294
295 func TestConsistency(t *testing.T) {
296 var x Element
297 var x2, x2sq Element
298
299 x = Element{1, 1, 1, 1, 1}
300 x2.Multiply(&x, &x)
301 x2sq.Square(&x)
302
303 if x2 != x2sq {
304 t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
305 }
306
307 var bytes [32]byte
308
309 _, err := io.ReadFull(rand.Reader, bytes[:])
310 if err != nil {
311 t.Fatal(err)
312 }
313 x.SetBytes(bytes[:])
314
315 x2.Multiply(&x, &x)
316 x2sq.Square(&x)
317
318 if x2 != x2sq {
319 t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
320 }
321 }
322
323 func TestEqual(t *testing.T) {
324 x := Element{1, 1, 1, 1, 1}
325 y := Element{5, 4, 3, 2, 1}
326
327 eq := x.Equal(&x)
328 if eq != 1 {
329 t.Errorf("wrong about equality")
330 }
331
332 eq = x.Equal(&y)
333 if eq != 0 {
334 t.Errorf("wrong about inequality")
335 }
336 }
337
338 func TestInvert(t *testing.T) {
339 x := Element{1, 1, 1, 1, 1}
340 one := Element{1, 0, 0, 0, 0}
341 var xinv, r Element
342
343 xinv.Invert(&x)
344 r.Multiply(&x, &xinv)
345 r.reduce()
346
347 if one != r {
348 t.Errorf("inversion identity failed, got: %x", r)
349 }
350
351 var bytes [32]byte
352
353 _, err := io.ReadFull(rand.Reader, bytes[:])
354 if err != nil {
355 t.Fatal(err)
356 }
357 x.SetBytes(bytes[:])
358
359 xinv.Invert(&x)
360 r.Multiply(&x, &xinv)
361 r.reduce()
362
363 if one != r {
364 t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
365 }
366
367 zero := Element{}
368 x.Set(&zero)
369 if xx := xinv.Invert(&x); xx != &xinv {
370 t.Errorf("inverting zero did not return the receiver")
371 } else if xinv.Equal(&zero) != 1 {
372 t.Errorf("inverting zero did not return zero")
373 }
374 }
375
376 func TestSelectSwap(t *testing.T) {
377 a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
378 b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
379
380 var c, d Element
381
382 c.Select(&a, &b, 1)
383 d.Select(&a, &b, 0)
384
385 if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
386 t.Errorf("Select failed")
387 }
388
389 c.Swap(&d, 0)
390
391 if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
392 t.Errorf("Swap failed")
393 }
394
395 c.Swap(&d, 1)
396
397 if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
398 t.Errorf("Swap failed")
399 }
400 }
401
402 func TestMult32(t *testing.T) {
403 mult32EquivalentToMul := func(x Element, y uint32) bool {
404 t1 := new(Element)
405 for i := 0; i < 100; i++ {
406 t1.Mult32(&x, y)
407 }
408
409 ty := new(Element)
410 ty.l0 = uint64(y)
411
412 t2 := new(Element)
413 for i := 0; i < 100; i++ {
414 t2.Multiply(&x, ty)
415 }
416
417 return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
418 }
419
420 if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil {
421 t.Error(err)
422 }
423 }
424
425 func TestSqrtRatio(t *testing.T) {
426
427 type test struct {
428 u, v string
429 wasSquare int
430 r string
431 }
432 var tests = []test{
433
434
435
436 {
437 "0000000000000000000000000000000000000000000000000000000000000000",
438 "0000000000000000000000000000000000000000000000000000000000000000",
439 1, "0000000000000000000000000000000000000000000000000000000000000000",
440 },
441
442 {
443 "0000000000000000000000000000000000000000000000000000000000000000",
444 "0100000000000000000000000000000000000000000000000000000000000000",
445 1, "0000000000000000000000000000000000000000000000000000000000000000",
446 },
447
448 {
449 "0100000000000000000000000000000000000000000000000000000000000000",
450 "0000000000000000000000000000000000000000000000000000000000000000",
451 0, "0000000000000000000000000000000000000000000000000000000000000000",
452 },
453
454 {
455 "0200000000000000000000000000000000000000000000000000000000000000",
456 "0100000000000000000000000000000000000000000000000000000000000000",
457 0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54",
458 },
459
460 {
461 "0400000000000000000000000000000000000000000000000000000000000000",
462 "0100000000000000000000000000000000000000000000000000000000000000",
463 1, "0200000000000000000000000000000000000000000000000000000000000000",
464 },
465
466 {
467 "0100000000000000000000000000000000000000000000000000000000000000",
468 "0400000000000000000000000000000000000000000000000000000000000000",
469 1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f",
470 },
471 }
472
473 for i, tt := range tests {
474 u := new(Element).SetBytes(decodeHex(tt.u))
475 v := new(Element).SetBytes(decodeHex(tt.v))
476 want := new(Element).SetBytes(decodeHex(tt.r))
477 got, wasSquare := new(Element).SqrtRatio(u, v)
478 if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
479 t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
480 }
481 }
482 }
483
484 func TestCarryPropagate(t *testing.T) {
485 asmLikeGeneric := func(a [5]uint64) bool {
486 t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
487 t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
488
489 t1.carryPropagate()
490 t2.carryPropagateGeneric()
491
492 if *t1 != *t2 {
493 t.Logf("got: %#v,\nexpected: %#v", t1, t2)
494 }
495
496 return *t1 == *t2 && isInBounds(t2)
497 }
498
499 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
500 t.Error(err)
501 }
502
503 if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) {
504 t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}")
505 }
506 }
507
508 func TestFeSquare(t *testing.T) {
509 asmLikeGeneric := func(a Element) bool {
510 t1 := a
511 t2 := a
512
513 feSquareGeneric(&t1, &t1)
514 feSquare(&t2, &t2)
515
516 if t1 != t2 {
517 t.Logf("got: %#v,\nexpected: %#v", t1, t2)
518 }
519
520 return t1 == t2 && isInBounds(&t2)
521 }
522
523 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
524 t.Error(err)
525 }
526 }
527
528 func TestFeMul(t *testing.T) {
529 asmLikeGeneric := func(a, b Element) bool {
530 a1 := a
531 a2 := a
532 b1 := b
533 b2 := b
534
535 feMulGeneric(&a1, &a1, &b1)
536 feMul(&a2, &a2, &b2)
537
538 if a1 != a2 || b1 != b2 {
539 t.Logf("got: %#v,\nexpected: %#v", a1, a2)
540 t.Logf("got: %#v,\nexpected: %#v", b1, b2)
541 }
542
543 return a1 == a2 && isInBounds(&a2) &&
544 b1 == b2 && isInBounds(&b2)
545 }
546
547 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
548 t.Error(err)
549 }
550 }
551
// decodeHex decodes the hexadecimal string s and panics on malformed input.
// It is intended for hard-coded test vectors only.
func decodeHex(s string) []byte {
	dst := make([]byte, hex.DecodedLen(len(s)))
	n, err := hex.Decode(dst, []byte(s))
	if err != nil {
		panic(err)
	}
	return dst[:n]
}
559
View as plain text