Skip to content

Commit

Permalink
API alignment + ERROR fix (#82)
Browse files Browse the repository at this point in the history
This PR does a couple of things

    Aligns the API a bit more closely to big.Int. Primarily, by making NewInt(uint64) take an input parameter.
    Changes AddOverflow(x, y *Int) bool -> AddOverflow(x, y *Int) (*Int, bool)
    Changes SubOverflow(x, y *Int) bool -> SubOverflow(x, y *Int) (*Int, bool)
    Adds AddUint64

Important This PR also fixes an error in SubUint64. Before this PR, if the receiver was not identical to the first argument, it would return wrong numbers in many cases, due to exiting early if the carry was zero.
  • Loading branch information
holiman authored Apr 26, 2021
1 parent 3c4134f commit b323bdc
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 51 deletions.
6 changes: 3 additions & 3 deletions benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ func benchmark_Lsh_Bit(n uint, bench *testing.B) {
f2, _ := FromBig(original)
bench.ResetTimer()
for i := 0; i < bench.N; i++ {
f1 := NewInt()
f1 := new(Int)
f1.Lsh(f2, n)
}
}
Expand Down Expand Up @@ -431,7 +431,7 @@ func benchmark_Rsh_Bit(n uint, bench *testing.B) {
f2, _ := FromBig(original)
bench.ResetTimer()
for i := 0; i < bench.N; i++ {
f1 := NewInt()
f1 := new(Int)
f1.Rsh(f2, n)
}
}
Expand Down Expand Up @@ -686,7 +686,7 @@ func benchmark_SdivLarge_Bit(bench *testing.B) {

bench.ResetTimer()
for i := 0; i < bench.N; i++ {
f := NewInt()
f := new(Int)
f.SDiv(fa, fb)
}
}
Expand Down
4 changes: 4 additions & 0 deletions conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,10 @@ func bigEndianUint56(b []byte) uint64 {
// EncodeRLP implements the rlp.Encoder interface from go-ethereum
// and writes the RLP encoding of z to w.
func (z *Int) EncodeRLP(w io.Writer) error {
if z == nil {
_, err := w.Write([]byte{0x80})
return err
}
nBits := z.BitLen()
if nBits == 0 {
_, err := w.Write([]byte{0x80})
Expand Down
12 changes: 6 additions & 6 deletions conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,22 +184,22 @@ func TestSetBytes(t *testing.T) {
for i := 0; i < 35; i++ {
buf := hex2Bytes("aaaa12131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f3031bbbb")
exp, _ := FromBig(new(big.Int).SetBytes(buf[0:i]))
z := NewInt().SetAllOne().SetBytes(buf[0:i])
z := new(Int).SetAllOne().SetBytes(buf[0:i])
if !z.Eq(exp) {
t.Errorf("testcase %d: exp %x, got %x", i, exp, z)
}
}
// nil check
exp, _ := FromBig(new(big.Int).SetBytes(nil))
z := NewInt().SetAllOne().SetBytes(nil)
z := new(Int).SetAllOne().SetBytes(nil)
if !z.Eq(exp) {
t.Errorf("nil-test : exp %x, got %x", exp, z)
}
}

func BenchmarkSetBytes(b *testing.B) {

val := NewInt()
val := new(Int)
bytearr := hex2Bytes("12131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f3031")
b.Run("generic", func(b *testing.B) {
b.ReportAllocs()
Expand Down Expand Up @@ -544,7 +544,7 @@ func TestRlpEncode(t *testing.T) {
{"4000000000000000000000000000000000000000000000000000000000000000", "a04000000000000000000000000000000000000000000000000000000000000000"},
{"8000000000000000000000000000000000000000000000000000000000000000", "a08000000000000000000000000000000000000000000000000000000000000000"},
} {
z := NewInt().SetBytes(hex2Bytes(tt.val))
z := new(Int).SetBytes(hex2Bytes(tt.val))
var b bytes.Buffer
w := bufio.NewWriter(&b)
if err := z.EncodeRLP(w); err != nil {
Expand All @@ -565,7 +565,7 @@ func (n2 *nilWriter) Write(p []byte) (n int, err error) {

// BenchmarkRLPEncoding writes 255 Ints ranging in bitsize from 0-255 in each op
func BenchmarkRLPEncoding(b *testing.B) {
z := NewInt()
z := new(Int)
devnull := &nilWriter{}
b.ReportAllocs()
b.ResetTimer()
Expand Down Expand Up @@ -742,7 +742,7 @@ func TestEnDecode(t *testing.T) {
if dec.Cmp(&intSample) != 0 {
t.Fatalf("test %d #6, got %v, exp %v", i, dec, intSample)
}
dec = NewInt()
dec = new(Int)
if err := dec.UnmarshalText([]byte(exp)); err != nil {
t.Fatalf("test %d #7, err: %v", i, err)
}
Expand Down
2 changes: 1 addition & 1 deletion div.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) {
}

return qh, r
}
}
54 changes: 30 additions & 24 deletions uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ import (
// so that Int[3] is the most significant, and Int[0] is the least significant
type Int [4]uint64

// NewInt returns a new zero-initialized Int.
func NewInt() *Int {
return &Int{}
// NewInt returns a new initialized Int.
func NewInt(val uint64) *Int {
z := &Int{}
z.SetUint64(val)
return z
}

// SetBytes interprets buf as the bytes of a big-endian unsigned
Expand Down Expand Up @@ -180,14 +182,14 @@ func (z *Int) Add(x, y *Int) *Int {
return z
}

// AddOverflow sets z to the sum x+y, and returns whether overflow occurred
func (z *Int) AddOverflow(x, y *Int) bool {
// AddOverflow sets z to the sum x+y, and returns z and whether overflow occurred
func (z *Int) AddOverflow(x, y *Int) (*Int, bool) {
var carry uint64
z[0], carry = bits.Add64(x[0], y[0], 0)
z[1], carry = bits.Add64(x[1], y[1], carry)
z[2], carry = bits.Add64(x[2], y[2], carry)
z[3], carry = bits.Add64(x[3], y[3], carry)
return carry != 0
return z, carry != 0
}

// AddMod sets z to the sum ( x+y ) mod m, and returns z.
Expand All @@ -199,7 +201,7 @@ func (z *Int) AddMod(x, y, m *Int) *Int {
if z == m { // z is an alias for m // TODO: Understand why needed and add tests for all "division" methods.
m = m.Clone()
}
if overflow := z.AddOverflow(x, y); overflow {
if _, overflow := z.AddOverflow(x, y); overflow {
sum := [5]uint64{z[0], z[1], z[2], z[3], 1}
var quot [5]uint64
rem := udivrem(quot[:], sum[:], m)
Expand All @@ -208,6 +210,17 @@ func (z *Int) AddMod(x, y, m *Int) *Int {
return z.Mod(z, m)
}

// AddUint64 sets z to x + y, where y is a uint64, and returns z
func (z *Int) AddUint64(x *Int, y uint64) *Int {
var carry uint64

z[0], carry = bits.Add64(x[0], y, 0)
z[1], carry = bits.Add64(x[1], 0, carry)
z[2], carry = bits.Add64(x[2], 0, carry)
z[3], _ = bits.Add64(x[3], 0, carry)
return z
}

// PaddedBytes encodes a Int as a 0-padded byte slice. The length
// of the slice is at least n bytes.
// Example, z =1, n = 20 => [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
Expand All @@ -223,28 +236,21 @@ func (z *Int) PaddedBytes(n int) []byte {
// SubUint64 set z to the difference x - y, where y is a uint64, and returns z
func (z *Int) SubUint64(x *Int, y uint64) *Int {
var carry uint64

if z[0], carry = bits.Sub64(x[0], y, carry); carry == 0 {
return z
}
if z[1], carry = bits.Sub64(x[1], 0, carry); carry == 0 {
return z
}
if z[2], carry = bits.Sub64(x[2], 0, carry); carry == 0 {
return z
}
z[3]--
z[0], carry = bits.Sub64(x[0], y, carry)
z[1], carry = bits.Sub64(x[1], 0, carry)
z[2], carry = bits.Sub64(x[2], 0, carry)
z[3], _ = bits.Sub64(x[3], 0, carry)
return z
}

// SubOverflow sets z to the difference x-y and returns true if the operation underflowed
func (z *Int) SubOverflow(x, y *Int) bool {
// SubOverflow sets z to the difference x-y and returns z and true if the operation underflowed
func (z *Int) SubOverflow(x, y *Int) (*Int, bool) {
var carry uint64
z[0], carry = bits.Sub64(x[0], y[0], 0)
z[1], carry = bits.Sub64(x[1], y[1], carry)
z[2], carry = bits.Sub64(x[2], y[2], carry)
z[3], carry = bits.Sub64(x[3], y[3], carry)
return carry != 0
return z, carry != 0
}

// Sub sets z to the difference x-y
Expand Down Expand Up @@ -331,11 +337,11 @@ func (z *Int) Mul(x, y *Int) *Int {
return z.Set(&res)
}

// MulOverflow sets z to the product x*y, and returns whether overflow occurred
func (z *Int) MulOverflow(x, y *Int) bool {
// MulOverflow sets z to the product x*y, and returns z and whether overflow occurred
func (z *Int) MulOverflow(x, y *Int) (*Int, bool) {
p := umul(x, y)
copy(z[:], p[:4])
return (p[4] | p[5] | p[6] | p[7]) != 0
return z, (p[4] | p[5] | p[6] | p[7]) != 0
}

func (z *Int) squared() {
Expand Down
60 changes: 43 additions & 17 deletions uint256_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ func randNums() (*big.Int, *Int, error) {
err := checkOverflow(b, f, overflow)
return b, f, err
}

func randHighNums() (*big.Int, *Int, error) {
//How many bits? 0-256
nbits := int64(256)
Expand Down Expand Up @@ -188,6 +189,7 @@ func testRandomOp(t *testing.T, nativeFunc func(a, b, c *Int), bigintFunc func(a
}
}
}

func TestRandomSubOverflow(t *testing.T) {
for i := 0; i < 10000; i++ {
b, f1, err := randNums()
Expand All @@ -199,7 +201,7 @@ func TestRandomSubOverflow(t *testing.T) {
t.Fatal(err)
}
f1a, f2a := f1.Clone(), f2.Clone()
overflow := f1.SubOverflow(f1, f2)
_, overflow := f1.SubOverflow(f1, f2)
b.Sub(b, b2)
if err := checkUnderflow(b, f1, overflow); err != nil {
t.Fatal(err)
Expand All @@ -209,6 +211,7 @@ func TestRandomSubOverflow(t *testing.T) {
}
}
}

func TestRandomSub(t *testing.T) {
testRandomOp(t,
func(f1, f2, f3 *Int) {
Expand All @@ -230,6 +233,7 @@ func TestRandomAdd(t *testing.T) {
},
)
}

func TestRandomMul(t *testing.T) {

testRandomOp(t,
Expand All @@ -241,6 +245,7 @@ func TestRandomMul(t *testing.T) {
},
)
}

func TestRandomMulOverflow(t *testing.T) {
for i := 0; i < 10000; i++ {
b, f1, err := randNums()
Expand All @@ -252,7 +257,7 @@ func TestRandomMulOverflow(t *testing.T) {
t.Fatal(err)
}
f1a, f2a := f1.Clone(), f2.Clone()
overflow := f1.MulOverflow(f1, f2)
_, overflow := f1.MulOverflow(f1, f2)
b.Mul(b, b2)
if err := checkOverflow(b, f1, overflow); err != nil {
t.Fatal(err)
Expand All @@ -262,6 +267,7 @@ func TestRandomMulOverflow(t *testing.T) {
}
}
}

func TestRandomSquare(t *testing.T) {
testRandomOp(t,
func(f1, f2, f3 *Int) {
Expand All @@ -272,6 +278,7 @@ func TestRandomSquare(t *testing.T) {
},
)
}

func TestRandomDiv(t *testing.T) {
testRandomOp(t,
func(f1, f2, f3 *Int) {
Expand Down Expand Up @@ -301,6 +308,7 @@ func TestRandomMod(t *testing.T) {
},
)
}

func TestRandomSMod(t *testing.T) {
testRandomOp(t,
func(f1, f2, f3 *Int) {
Expand Down Expand Up @@ -484,21 +492,21 @@ func TestSRsh(t *testing.T) {

func TestByte(t *testing.T) {
z := new(Int).SetBytes(hex2Bytes("ABCDEF09080706050403020100000000000000000000000000000000000000ef"))
actual := z.Byte(NewInt().SetUint64(0))
actual := z.Byte(NewInt(0))
expected := new(Int).SetBytes(hex2Bytes("00000000000000000000000000000000000000000000000000000000000000ab"))
if !actual.Eq(expected) {
t.Fatalf("Expected %x, got %x", expected, actual)
}

z = new(Int).SetBytes(hex2Bytes("ABCDEF09080706050403020100000000000000000000000000000000000000ef"))
actual = z.Byte(NewInt().SetUint64(31))
actual = z.Byte(NewInt(31))
expected = new(Int).SetBytes(hex2Bytes("00000000000000000000000000000000000000000000000000000000000000ef"))
if !actual.Eq(expected) {
t.Fatalf("Expected %x, got %x", expected, actual)
}

z = new(Int).SetBytes(hex2Bytes("ABCDEF09080706050403020100000000000000000000000000000000000000ef"))
actual = z.Byte(NewInt().SetUint64(32))
actual = z.Byte(NewInt(32))
expected = new(Int).SetBytes(hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"))
if !actual.Eq(expected) {
t.Fatalf("Expected %x, got %x", expected, actual)
Expand Down Expand Up @@ -559,7 +567,7 @@ func TestSignExtend(t *testing.T) {
}
}

func TestSubUint64(t *testing.T) {
func TestAddSubUint64(t *testing.T) {
type testCase struct {
arg string
n uint64
Expand All @@ -571,20 +579,38 @@ func TestSubUint64(t *testing.T) {
{"1", 3},
{"0x10000000000000000", 1},
{"0x100000000000000000000000000000000", 1},
{"0", 0xffffffffffffffff},
{"1", 0xffffffff},
{"0xffffffffffffffff", 1},
{"0xffffffffffffffff", 0xffffffffffffffff},
{"0x10000000000000000", 1},
{"0xfffffffffffffffffffffffffffffffff", 1},
{"0xfffffffffffffffffffffffffffffffff", 2},
}

for i := 0; i < len(testCases); i++ {
tc := &testCases[i]
bigArg, _ := new(big.Int).SetString(tc.arg, 0)
arg, _ := FromBig(bigArg)
expected, _ := FromBig(U256(new(big.Int).Sub(bigArg, new(big.Int).SetUint64(tc.n))))
result := new(Int).SubUint64(arg, tc.n)

if !result.Eq(expected) {
t.Logf("args: %s, %d\n", tc.arg, tc.n)
t.Logf("exp : %x\n", expected)
t.Logf("got : %x\n\n", result)
t.Fail()
{ // SubUint64
want, _ := FromBig(U256(new(big.Int).Sub(bigArg, new(big.Int).SetUint64(tc.n))))
have := new(Int).SetAllOne().SubUint64(arg, tc.n)
if !have.Eq(want) {
t.Logf("args: %s, %d\n", tc.arg, tc.n)
t.Logf("want : %x\n", want)
t.Logf("have : %x\n\n", have)
t.Fail()
}
}
{ // AddUint64
want, _ := FromBig(U256(new(big.Int).Add(bigArg, new(big.Int).SetUint64(tc.n))))
have := new(Int).AddUint64(arg, tc.n)
if !have.Eq(want) {
t.Logf("args: %s, %d\n", tc.arg, tc.n)
t.Logf("want : %x\n", want)
t.Logf("have : %x\n\n", have)
t.Fail()
}
}
}
}
Expand Down Expand Up @@ -1071,7 +1097,7 @@ func TestWriteToSlice(t *testing.T) {
t.Errorf("got %x, expected %x", dest, x1)
}

fb := NewInt()
fb := new(Int)
exp := make([]byte, 32)
fb.WriteToSlice(dest)
if !bytes.Equal(dest, exp) {
Expand Down Expand Up @@ -1191,7 +1217,7 @@ func TestByte20Representation(t *testing.T) {
exp := bytesToAddress(a.Bytes())

// uint256.Int -> address
b := NewInt().SetBytes(bytearr)
b := new(Int).SetBytes(bytearr)
got := gethAddress(b.Bytes20())

if got != exp {
Expand Down Expand Up @@ -1220,7 +1246,7 @@ func TestByte32Representation(t *testing.T) {
exp := bytesToHash(a.Bytes())

// uint256.Int -> address
b := NewInt().SetBytes(bytearr)
b := new(Int).SetBytes(bytearr)
got := gethHash(b.Bytes32())

if got != exp {
Expand Down

0 comments on commit b323bdc

Please sign in to comment.