diff --git a/dataconv/types/hybrid.go b/dataconv/types/hybrid.go index bc98f6a..82c8283 100644 --- a/dataconv/types/hybrid.go +++ b/dataconv/types/hybrid.go @@ -3,6 +3,7 @@ package types import ( "fmt" + "math" "go.starlark.net/starlark" ) @@ -53,17 +54,38 @@ func (p FloatOrInt) GoFloat64() float64 { // GoInt returns the Go int representation of the FloatOrInt. func (p FloatOrInt) GoInt() int { - return int(p) + f := float64(p) + if f < float64(math.MinInt) || f > float64(math.MaxInt) { + if f < 0 { + return math.MinInt + } + return math.MaxInt + } + return int(f) } // GoInt32 returns the Go int32 representation of the FloatOrInt. func (p FloatOrInt) GoInt32() int32 { - return int32(p) + f := float64(p) + if f < float64(math.MinInt32) || f > float64(math.MaxInt32) { + if f < 0 { + return math.MinInt32 + } + return math.MaxInt32 + } + return int32(f) } // GoInt64 returns the Go int64 representation of the FloatOrInt. func (p FloatOrInt) GoInt64() int64 { - return int64(p) + f := float64(p) + if f < float64(math.MinInt64) || f > float64(math.MaxInt64) { + if f < 0 { + return math.MinInt64 + } + return math.MaxInt64 + } + return int64(f) } // NumericValue holds a Starlark numeric value and tracks its type. diff --git a/dataconv/types/hybrid_test.go b/dataconv/types/hybrid_test.go index 75ea5d9..d975b1b 100644 --- a/dataconv/types/hybrid_test.go +++ b/dataconv/types/hybrid_test.go @@ -48,6 +48,7 @@ func TestFloatOrInt_Value(t *testing.T) { v FloatOrInt wantInt int wantInt32 int32 + wantInt64 int64 wantFlt float64 }{ { @@ -55,6 +56,7 @@ func TestFloatOrInt_Value(t *testing.T) { v: 0, wantInt: 0, wantInt32: 0, + wantInt64: 0, wantFlt: 0, }, { @@ -62,6 +64,7 @@ func TestFloatOrInt_Value(t *testing.T) { v: 1, wantInt: 1, wantInt32: 1, + wantInt64: 1, wantFlt: 1, }, { @@ -69,6 +72,7 @@ func TestFloatOrInt_Value(t *testing.T) { v: 1.2, wantInt: 1, wantInt32: 1, + wantInt64: 1, wantFlt: 1.2, }, { @@ -76,8 +80,17 @@ func TestFloatOrInt_Value(t *testing.T) { v: 1e12 + 1, wantInt: 1000000000001, wantInt32: 2147483647, + wantInt64: 1000000000001, wantFlt: 1e12 + 1, }, + { + name: "underflow", + v: -1e12 - 1, + wantInt: -1000000000001, + wantInt32: -2147483648, + wantInt64: -1000000000001, + wantFlt: -1e12 - 1, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -87,8 +100,8 @@ func TestFloatOrInt_Value(t *testing.T) { if got := tt.v.GoInt32(); got != tt.wantInt32 { t.Errorf("FloatOrInt.GoInt32() = %v, want %v", got, tt.wantInt32) } - if got := tt.v.GoInt64(); got != int64(tt.wantInt) { - t.Errorf("FloatOrInt.GoInt64() = %v, want %v", got, int64(tt.wantInt)) + if got := tt.v.GoInt64(); got != tt.wantInt64 { + t.Errorf("FloatOrInt.GoInt64() = %v, want %v", got, tt.wantInt64) } if got := tt.v.GoFloat(); got != tt.wantFlt {