Source file src/crypto/ed25519/internal/edwards25519/field/fe_test.go

     1  // Copyright (c) 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package field
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"io"
    12  	"math/big"
    13  	"math/bits"
    14  	mathrand "math/rand"
    15  	"reflect"
    16  	"testing"
    17  	"testing/quick"
    18  )
    19  
    20  func (v Element) String() string {
    21  	return hex.EncodeToString(v.Bytes())
    22  }
    23  
    24  // quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks)
    25  // times. The default value of -quickchecks is 100.
    26  var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10}
    27  
    28  func generateFieldElement(rand *mathrand.Rand) Element {
    29  	const maskLow52Bits = (1 << 52) - 1
    30  	return Element{
    31  		rand.Uint64() & maskLow52Bits,
    32  		rand.Uint64() & maskLow52Bits,
    33  		rand.Uint64() & maskLow52Bits,
    34  		rand.Uint64() & maskLow52Bits,
    35  		rand.Uint64() & maskLow52Bits,
    36  	}
    37  }
    38  
    39  // weirdLimbs can be combined to generate a range of edge-case field elements.
    40  // 0 and -1 are intentionally more weighted, as they combine well.
    41  var (
    42  	weirdLimbs51 = []uint64{
    43  		0, 0, 0, 0,
    44  		1,
    45  		19 - 1,
    46  		19,
    47  		0x2aaaaaaaaaaaa,
    48  		0x5555555555555,
    49  		(1 << 51) - 20,
    50  		(1 << 51) - 19,
    51  		(1 << 51) - 1, (1 << 51) - 1,
    52  		(1 << 51) - 1, (1 << 51) - 1,
    53  	}
    54  	weirdLimbs52 = []uint64{
    55  		0, 0, 0, 0, 0, 0,
    56  		1,
    57  		19 - 1,
    58  		19,
    59  		0x2aaaaaaaaaaaa,
    60  		0x5555555555555,
    61  		(1 << 51) - 20,
    62  		(1 << 51) - 19,
    63  		(1 << 51) - 1, (1 << 51) - 1,
    64  		(1 << 51) - 1, (1 << 51) - 1,
    65  		(1 << 51) - 1, (1 << 51) - 1,
    66  		1 << 51,
    67  		(1 << 51) + 1,
    68  		(1 << 52) - 19,
    69  		(1 << 52) - 1,
    70  	}
    71  )
    72  
    73  func generateWeirdFieldElement(rand *mathrand.Rand) Element {
    74  	return Element{
    75  		weirdLimbs52[rand.Intn(len(weirdLimbs52))],
    76  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    77  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    78  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    79  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    80  	}
    81  }
    82  
    83  func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
    84  	if rand.Intn(2) == 0 {
    85  		return reflect.ValueOf(generateWeirdFieldElement(rand))
    86  	}
    87  	return reflect.ValueOf(generateFieldElement(rand))
    88  }
    89  
    90  // isInBounds returns whether the element is within the expected bit size bounds
    91  // after a light reduction.
    92  func isInBounds(x *Element) bool {
    93  	return bits.Len64(x.l0) <= 52 &&
    94  		bits.Len64(x.l1) <= 52 &&
    95  		bits.Len64(x.l2) <= 52 &&
    96  		bits.Len64(x.l3) <= 52 &&
    97  		bits.Len64(x.l4) <= 52
    98  }
    99  
   100  func TestMultiplyDistributesOverAdd(t *testing.T) {
   101  	multiplyDistributesOverAdd := func(x, y, z Element) bool {
   102  		// Compute t1 = (x+y)*z
   103  		t1 := new(Element)
   104  		t1.Add(&x, &y)
   105  		t1.Multiply(t1, &z)
   106  
   107  		// Compute t2 = x*z + y*z
   108  		t2 := new(Element)
   109  		t3 := new(Element)
   110  		t2.Multiply(&x, &z)
   111  		t3.Multiply(&y, &z)
   112  		t2.Add(t2, t3)
   113  
   114  		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
   115  	}
   116  
   117  	if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil {
   118  		t.Error(err)
   119  	}
   120  }
   121  
   122  func TestMul64to128(t *testing.T) {
   123  	a := uint64(5)
   124  	b := uint64(5)
   125  	r := mul64(a, b)
   126  	if r.lo != 0x19 || r.hi != 0 {
   127  		t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
   128  	}
   129  
   130  	a = uint64(18014398509481983) // 2^54 - 1
   131  	b = uint64(18014398509481983) // 2^54 - 1
   132  	r = mul64(a, b)
   133  	if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff {
   134  		t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
   135  	}
   136  
   137  	a = uint64(1125899906842661)
   138  	b = uint64(2097155)
   139  	r = mul64(a, b)
   140  	r = addMul64(r, a, b)
   141  	r = addMul64(r, a, b)
   142  	r = addMul64(r, a, b)
   143  	r = addMul64(r, a, b)
   144  	if r.lo != 16888498990613035 || r.hi != 640 {
   145  		t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi)
   146  	}
   147  }
   148  
   149  func TestSetBytesRoundTrip(t *testing.T) {
   150  	f1 := func(in [32]byte, fe Element) bool {
   151  		fe.SetBytes(in[:])
   152  
   153  		// Mask the most significant bit as it's ignored by SetBytes. (Now
   154  		// instead of earlier so we check the masking in SetBytes is working.)
   155  		in[len(in)-1] &= (1 << 7) - 1
   156  
   157  		return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
   158  	}
   159  	if err := quick.Check(f1, nil); err != nil {
   160  		t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
   161  	}
   162  
   163  	f2 := func(fe, r Element) bool {
   164  		r.SetBytes(fe.Bytes())
   165  
   166  		// Intentionally not using Equal not to go through Bytes again.
   167  		// Calling reduce because both Generate and SetBytes can produce
   168  		// non-canonical representations.
   169  		fe.reduce()
   170  		r.reduce()
   171  		return fe == r
   172  	}
   173  	if err := quick.Check(f2, nil); err != nil {
   174  		t.Errorf("failed FE->bytes->FE round-trip: %v", err)
   175  	}
   176  
   177  	// Check some fixed vectors from dalek
   178  	type feRTTest struct {
   179  		fe Element
   180  		b  []byte
   181  	}
   182  	var tests = []feRTTest{
   183  		{
   184  			fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
   185  			b:  []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
   186  		},
   187  		{
   188  			fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
   189  			b:  []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
   190  		},
   191  	}
   192  
   193  	for _, tt := range tests {
   194  		b := tt.fe.Bytes()
   195  		if !bytes.Equal(b, tt.b) || new(Element).SetBytes(tt.b).Equal(&tt.fe) != 1 {
   196  			t.Errorf("Failed fixed roundtrip: %v", tt)
   197  		}
   198  	}
   199  }
   200  
   201  func swapEndianness(buf []byte) []byte {
   202  	for i := 0; i < len(buf)/2; i++ {
   203  		buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i]
   204  	}
   205  	return buf
   206  }
   207  
   208  func TestBytesBigEquivalence(t *testing.T) {
   209  	f1 := func(in [32]byte, fe, fe1 Element) bool {
   210  		fe.SetBytes(in[:])
   211  
   212  		in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit
   213  		b := new(big.Int).SetBytes(swapEndianness(in[:]))
   214  		fe1.fromBig(b)
   215  
   216  		if fe != fe1 {
   217  			return false
   218  		}
   219  
   220  		buf := make([]byte, 32) // pad with zeroes
   221  		copy(buf, swapEndianness(fe1.toBig().Bytes()))
   222  
   223  		return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1)
   224  	}
   225  	if err := quick.Check(f1, nil); err != nil {
   226  		t.Error(err)
   227  	}
   228  }
   229  
   230  // fromBig sets v = n, and returns v. The bit length of n must not exceed 256.
   231  func (v *Element) fromBig(n *big.Int) *Element {
   232  	if n.BitLen() > 32*8 {
   233  		panic("edwards25519: invalid field element input size")
   234  	}
   235  
   236  	buf := make([]byte, 0, 32)
   237  	for _, word := range n.Bits() {
   238  		for i := 0; i < bits.UintSize; i += 8 {
   239  			if len(buf) >= cap(buf) {
   240  				break
   241  			}
   242  			buf = append(buf, byte(word))
   243  			word >>= 8
   244  		}
   245  	}
   246  
   247  	return v.SetBytes(buf[:32])
   248  }
   249  
   250  func (v *Element) fromDecimal(s string) *Element {
   251  	n, ok := new(big.Int).SetString(s, 10)
   252  	if !ok {
   253  		panic("not a valid decimal: " + s)
   254  	}
   255  	return v.fromBig(n)
   256  }
   257  
   258  // toBig returns v as a big.Int.
   259  func (v *Element) toBig() *big.Int {
   260  	buf := v.Bytes()
   261  
   262  	words := make([]big.Word, 32*8/bits.UintSize)
   263  	for n := range words {
   264  		for i := 0; i < bits.UintSize; i += 8 {
   265  			if len(buf) == 0 {
   266  				break
   267  			}
   268  			words[n] |= big.Word(buf[0]) << big.Word(i)
   269  			buf = buf[1:]
   270  		}
   271  	}
   272  
   273  	return new(big.Int).SetBits(words)
   274  }
   275  
   276  func TestDecimalConstants(t *testing.T) {
   277  	sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
   278  	if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
   279  		t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
   280  	}
   281  	// d is in the parent package, and we don't want to expose d or fromDecimal.
   282  	// dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555"
   283  	// if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 {
   284  	// 	t.Errorf("d is %v, expected %v", d, exp)
   285  	// }
   286  }
   287  
   288  func TestSetBytesRoundTripEdgeCases(t *testing.T) {
   289  	// TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1,
   290  	// and between 2^255 and 2^256-1. Test both the documented SetBytes
   291  	// behavior, and that Bytes reduces them.
   292  }
   293  
   294  // Tests self-consistency between Multiply and Square.
   295  func TestConsistency(t *testing.T) {
   296  	var x Element
   297  	var x2, x2sq Element
   298  
   299  	x = Element{1, 1, 1, 1, 1}
   300  	x2.Multiply(&x, &x)
   301  	x2sq.Square(&x)
   302  
   303  	if x2 != x2sq {
   304  		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
   305  	}
   306  
   307  	var bytes [32]byte
   308  
   309  	_, err := io.ReadFull(rand.Reader, bytes[:])
   310  	if err != nil {
   311  		t.Fatal(err)
   312  	}
   313  	x.SetBytes(bytes[:])
   314  
   315  	x2.Multiply(&x, &x)
   316  	x2sq.Square(&x)
   317  
   318  	if x2 != x2sq {
   319  		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
   320  	}
   321  }
   322  
   323  func TestEqual(t *testing.T) {
   324  	x := Element{1, 1, 1, 1, 1}
   325  	y := Element{5, 4, 3, 2, 1}
   326  
   327  	eq := x.Equal(&x)
   328  	if eq != 1 {
   329  		t.Errorf("wrong about equality")
   330  	}
   331  
   332  	eq = x.Equal(&y)
   333  	if eq != 0 {
   334  		t.Errorf("wrong about inequality")
   335  	}
   336  }
   337  
   338  func TestInvert(t *testing.T) {
   339  	x := Element{1, 1, 1, 1, 1}
   340  	one := Element{1, 0, 0, 0, 0}
   341  	var xinv, r Element
   342  
   343  	xinv.Invert(&x)
   344  	r.Multiply(&x, &xinv)
   345  	r.reduce()
   346  
   347  	if one != r {
   348  		t.Errorf("inversion identity failed, got: %x", r)
   349  	}
   350  
   351  	var bytes [32]byte
   352  
   353  	_, err := io.ReadFull(rand.Reader, bytes[:])
   354  	if err != nil {
   355  		t.Fatal(err)
   356  	}
   357  	x.SetBytes(bytes[:])
   358  
   359  	xinv.Invert(&x)
   360  	r.Multiply(&x, &xinv)
   361  	r.reduce()
   362  
   363  	if one != r {
   364  		t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
   365  	}
   366  
   367  	zero := Element{}
   368  	x.Set(&zero)
   369  	if xx := xinv.Invert(&x); xx != &xinv {
   370  		t.Errorf("inverting zero did not return the receiver")
   371  	} else if xinv.Equal(&zero) != 1 {
   372  		t.Errorf("inverting zero did not return zero")
   373  	}
   374  }
   375  
   376  func TestSelectSwap(t *testing.T) {
   377  	a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
   378  	b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
   379  
   380  	var c, d Element
   381  
   382  	c.Select(&a, &b, 1)
   383  	d.Select(&a, &b, 0)
   384  
   385  	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
   386  		t.Errorf("Select failed")
   387  	}
   388  
   389  	c.Swap(&d, 0)
   390  
   391  	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
   392  		t.Errorf("Swap failed")
   393  	}
   394  
   395  	c.Swap(&d, 1)
   396  
   397  	if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
   398  		t.Errorf("Swap failed")
   399  	}
   400  }
   401  
   402  func TestMult32(t *testing.T) {
   403  	mult32EquivalentToMul := func(x Element, y uint32) bool {
   404  		t1 := new(Element)
   405  		for i := 0; i < 100; i++ {
   406  			t1.Mult32(&x, y)
   407  		}
   408  
   409  		ty := new(Element)
   410  		ty.l0 = uint64(y)
   411  
   412  		t2 := new(Element)
   413  		for i := 0; i < 100; i++ {
   414  			t2.Multiply(&x, ty)
   415  		}
   416  
   417  		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
   418  	}
   419  
   420  	if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil {
   421  		t.Error(err)
   422  	}
   423  }
   424  
   425  func TestSqrtRatio(t *testing.T) {
   426  	// From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4.
   427  	type test struct {
   428  		u, v      string
   429  		wasSquare int
   430  		r         string
   431  	}
   432  	var tests = []test{
   433  		// If u is 0, the function is defined to return (0, TRUE), even if v
   434  		// is zero. Note that where used in this package, the denominator v
   435  		// is never zero.
   436  		{
   437  			"0000000000000000000000000000000000000000000000000000000000000000",
   438  			"0000000000000000000000000000000000000000000000000000000000000000",
   439  			1, "0000000000000000000000000000000000000000000000000000000000000000",
   440  		},
   441  		// 0/1 == 0²
   442  		{
   443  			"0000000000000000000000000000000000000000000000000000000000000000",
   444  			"0100000000000000000000000000000000000000000000000000000000000000",
   445  			1, "0000000000000000000000000000000000000000000000000000000000000000",
   446  		},
   447  		// If u is non-zero and v is zero, defined to return (0, FALSE).
   448  		{
   449  			"0100000000000000000000000000000000000000000000000000000000000000",
   450  			"0000000000000000000000000000000000000000000000000000000000000000",
   451  			0, "0000000000000000000000000000000000000000000000000000000000000000",
   452  		},
   453  		// 2/1 is not square in this field.
   454  		{
   455  			"0200000000000000000000000000000000000000000000000000000000000000",
   456  			"0100000000000000000000000000000000000000000000000000000000000000",
   457  			0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54",
   458  		},
   459  		// 4/1 == 2²
   460  		{
   461  			"0400000000000000000000000000000000000000000000000000000000000000",
   462  			"0100000000000000000000000000000000000000000000000000000000000000",
   463  			1, "0200000000000000000000000000000000000000000000000000000000000000",
   464  		},
   465  		// 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem
   466  		{
   467  			"0100000000000000000000000000000000000000000000000000000000000000",
   468  			"0400000000000000000000000000000000000000000000000000000000000000",
   469  			1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f",
   470  		},
   471  	}
   472  
   473  	for i, tt := range tests {
   474  		u := new(Element).SetBytes(decodeHex(tt.u))
   475  		v := new(Element).SetBytes(decodeHex(tt.v))
   476  		want := new(Element).SetBytes(decodeHex(tt.r))
   477  		got, wasSquare := new(Element).SqrtRatio(u, v)
   478  		if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
   479  			t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
   480  		}
   481  	}
   482  }
   483  
   484  func TestCarryPropagate(t *testing.T) {
   485  	asmLikeGeneric := func(a [5]uint64) bool {
   486  		t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
   487  		t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
   488  
   489  		t1.carryPropagate()
   490  		t2.carryPropagateGeneric()
   491  
   492  		if *t1 != *t2 {
   493  			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
   494  		}
   495  
   496  		return *t1 == *t2 && isInBounds(t2)
   497  	}
   498  
   499  	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
   500  		t.Error(err)
   501  	}
   502  
   503  	if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) {
   504  		t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}")
   505  	}
   506  }
   507  
   508  func TestFeSquare(t *testing.T) {
   509  	asmLikeGeneric := func(a Element) bool {
   510  		t1 := a
   511  		t2 := a
   512  
   513  		feSquareGeneric(&t1, &t1)
   514  		feSquare(&t2, &t2)
   515  
   516  		if t1 != t2 {
   517  			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
   518  		}
   519  
   520  		return t1 == t2 && isInBounds(&t2)
   521  	}
   522  
   523  	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
   524  		t.Error(err)
   525  	}
   526  }
   527  
   528  func TestFeMul(t *testing.T) {
   529  	asmLikeGeneric := func(a, b Element) bool {
   530  		a1 := a
   531  		a2 := a
   532  		b1 := b
   533  		b2 := b
   534  
   535  		feMulGeneric(&a1, &a1, &b1)
   536  		feMul(&a2, &a2, &b2)
   537  
   538  		if a1 != a2 || b1 != b2 {
   539  			t.Logf("got: %#v,\nexpected: %#v", a1, a2)
   540  			t.Logf("got: %#v,\nexpected: %#v", b1, b2)
   541  		}
   542  
   543  		return a1 == a2 && isInBounds(&a2) &&
   544  			b1 == b2 && isInBounds(&b2)
   545  	}
   546  
   547  	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
   548  		t.Error(err)
   549  	}
   550  }
   551  
   552  func decodeHex(s string) []byte {
   553  	b, err := hex.DecodeString(s)
   554  	if err != nil {
   555  		panic(err)
   556  	}
   557  	return b
   558  }
   559  

View as plain text