Skip to content

Commit

Permalink
ENH: add support for list to execute (Take 2) (#573)
Browse files Browse the repository at this point in the history
* ENH: add support for list to execute

* Update rbc/omniscidb.py

Co-authored-by: Pearu Peterson <[email protected]>

* Add fromtypes to reduce a list of types

* Fix converting list to SQL array literal. Add test_direct_call_array.

* Refactor fromtypes to reducetypes and fix int16/float16

* Parametrize test with dtype

* Fix error with ints

* Unregister in tests

* Disable fail-fast

* Change function name for array tests

* REVERT THIS COMMIT: Disable vectorization in tests

* Fix mean from array_api

* Rename functions to less than 22 chars

* Rename OmnisciArrayType -> HeavyDBArrayType

* Test requires most recent version of HeavyDB

* add pytest.skip

* OmnisciArrayType -> HeavyDBArrayType

* add remote list evaluation test

* skip test when HeavyDB version < 7.0

* add test for ndarray and TextEncodingNone

* Revert "add test for ndarray and TextEncodingNone"

This reverts commit 857db58.

---------

Co-authored-by: Pamphile Roy <[email protected]>
Co-authored-by: Pearu Peterson <[email protected]>
  • Loading branch information
3 people authored Jul 14, 2023
1 parent dcfc843 commit 8da592e
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/rbc_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ on:
branches:
- main

env:
NUMBA_LOOP_VECTORIZE: 0
NUMBA_SLP_VECTORIZE: 0


# kill any previous running job on a new commit
concurrency:
group: build-and-test-rbc-${{ github.head_ref }}
Expand Down
13 changes: 12 additions & 1 deletion rbc/heavydb/remoteheavydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def type_to_type_name(typ: typesystem.Type):
).get(styp)
if type_name is not None:
return type_name

raise NotImplementedError(f'converting `{styp}` to DatumType not supported')


Expand Down Expand Up @@ -1423,7 +1424,13 @@ def get_types(self, *values):
"""
types = []
for value in values:
if isinstance(value, RemoteCallCapsule):

if isinstance(value, (list, numpy.ndarray)):
items_types = set(map(typesystem.Type.fromvalue, value))
com_type = typesystem.Type.reducetypes(items_types)
array_type = HeavyDBArrayType((com_type,))
types.append(array_type)
elif isinstance(value, RemoteCallCapsule):
typ = value.__typesystem_type__
if typ.is_struct and typ._params.get('struct_is_tuple'):
types.extend(typ)
Expand Down Expand Up @@ -1558,6 +1565,10 @@ def remote_call(self, func, ftype: typesystem.Type, arguments: tuple, hold=False
elif isinstance(a, str):
a = repr(a)
args.append(f'{a}')
elif isinstance(atype, HeavyDBArrayType):
element_type_name = type_to_type_name(atype.element_type)
astr = ", ".join([f'CAST({a_} AS {element_type_name})' for a_ in a])
args.append(f'ARRAY[{astr}]')
else:
args.append(f'CAST({a} AS {type_to_type_name(atype)})')
args = ', '.join(args)
Expand Down
18 changes: 18 additions & 0 deletions rbc/tests/heavydb/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def myupper(s):
r[i] = c
return r

# UDF used by the remote list-evaluation test: returns the number of
# elements of its array argument as int64. The signature template T is
# instantiated for int64 and double element types, and the function is
# declared for the CPU device only.
@heavydb('int64(T[])', T=['int64', 'double'], devices=['CPU'])
def mylength(lst):
return len(lst)


def test_udf_string_repr(heavydb):
myincr = heavydb.get_caller('myincr')
Expand Down Expand Up @@ -119,6 +123,20 @@ def test_remote_TextEncodingNone_evaluation(heavydb):
assert str(myupper(b"abc").execute()) == 'ABC'


def test_remote_list_evaluation(heavydb):
    """Check that Python lists are rendered as SQL ARRAY literals and evaluated.

    For each input list the generated query text is compared first, then the
    query is executed remotely and must report a length of 2.
    """
    if heavydb.version[:2] < (7, 0):
        pytest.skip('Test requires HeavyDB 7.0 or newer')

    mylength = heavydb.get_caller('mylength')
    assert str(mylength) == "mylength['int64(T[]), T=int64|double, device=CPU']"

    # (input list, expected SQL): an all-int list is cast element-wise to
    # BIGINT, a mixed int/float list is cast element-wise to DOUBLE.
    cases = (
        ([1, 2],
         "SELECT mylength(ARRAY[CAST(1 AS BIGINT), CAST(2 AS BIGINT)])"),
        ([1.0, 2],
         "SELECT mylength(ARRAY[CAST(1.0 AS DOUBLE), CAST(2 AS DOUBLE)])"),
    )
    for values, expected_sql in cases:
        assert str(mylength(values)) == expected_sql
        assert mylength(values).execute() == 2


def test_remote_composite_udf_evaluation(heavydb):
myincr = heavydb.get_caller('myincr')

Expand Down
19 changes: 18 additions & 1 deletion rbc/tests/heavydb/test_heavydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rbc.errors import UnsupportedError, HeavyDBServerError
from rbc.tests import heavydb_fixture, assert_equal
from rbc.typesystem import Type
from rbc.stdlib import array_api

rbc_heavydb = pytest.importorskip('rbc.heavydb')
available_version, reason = rbc_heavydb.is_available()
Expand Down Expand Up @@ -67,7 +68,7 @@ def heavydb():
yield o


def test_direct_call(heavydb):
def test_direct_call_scalar(heavydb):
heavydb.reset()

@heavydb('double(double)')
Expand All @@ -77,6 +78,22 @@ def farhenheit2celcius(f):
assert_equal(farhenheit2celcius(40).execute(), np.float32(40 / 9))


@pytest.mark.parametrize('dtype', ('float32', 'float64', 'int32', 'int64'))
def test_direct_call_array(heavydb, dtype):
    """Call an array-taking UDF directly with a NumPy array of each dtype.

    The UDF converts the mean of its input to another scale; the remote
    result is compared approximately against the same value computed locally
    and cast to the parametrized dtype.
    """
    server_version = heavydb.version[:2]
    if server_version < (7, 0):
        pytest.skip('Test requires HeavyDB 7.0 or newer')

    # Drop previously registered functions before defining this test's UDF.
    heavydb.unregister()

    @heavydb('T(T[])', T=['float32', 'float64', 'int64', 'int32'], devices=['cpu'])
    def func(f):
        return (array_api.mean(f)-32) * 5 / 9

    temperatures = np.array([30, 50], dtype=dtype)
    scalar_type = np.dtype(dtype).type
    expected = scalar_type(40 / 9)

    result = func(temperatures).execute()
    assert result == pytest.approx(expected)


def test_local_caller(heavydb):
heavydb.reset()

Expand Down
27 changes: 26 additions & 1 deletion rbc/tests/test_typesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
np = None

import pytest
from rbc.typesystem import Type, get_signature
from rbc.typesystem import Type, get_signature, TypeParseError
from rbc.utils import get_datamodel
from rbc.targetinfo import TargetInfo

Expand Down Expand Up @@ -82,6 +82,31 @@ def test_commasplit():
assert '^'.join(commasplit('a[:, :], b[:, :, :]')) == 'a[:, :]^b[:, :, :]'


def test_fromtypes():
    """Exercise ``Type.reducetypes`` on sets of types derived from values.

    The checks that need only builtin Python values (including the failure
    case) run first; the NumPy scalar cases are skipped when NumPy could not
    be imported, since this module falls back to ``np = None`` in that case.
    """

    def from_list(values):
        # Mirror how callers build the input: one Type per value, deduplicated.
        return set(map(Type.fromvalue, values))

    assert Type.reducetypes(from_list([1, 3., 6])) == Type('float64')
    assert Type.reducetypes(from_list([1, 3, 6])) == Type('int64')

    # A non-numeric item cannot be reduced together with numeric ones.
    msg = "Failed to cast"
    with pytest.raises(TypeParseError, match=msg):
        Type.reducetypes(from_list([1, 'a', 6]))

    if np is None:
        pytest.skip('NumPy is not available')

    # (values, expected reduced type name)
    numpy_cases = (
        ([np.int8(1), np.int8(3), np.int8(6)], 'int8'),
        ([np.int8(1), np.int32(3), np.int16(6)], 'int32'),
        ([np.uint8(1), np.uint32(3), np.int8(6)], 'int32'),
        ([np.int16(1), np.float16(3), np.int8(6)], 'float32'),
    )
    for values, expected in numpy_cases:
        assert Type.reducetypes(from_list(values)) == Type(expected)


def test_fromstring(target_info):

assert Type.fromstring('void') == Type()
Expand Down
54 changes: 54 additions & 0 deletions rbc/typesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,60 @@ def fromobject(cls, obj):
return cls.fromcallable(obj)
raise NotImplementedError(repr((type(obj))))

@classmethod
def reducetypes(cls, t, *, method='highest'):
    """Reduce a collection of types into a single type.

    Integer and floating-point types are merged into one numeric type:

    * only integers: the widest bit size seen; the result is unsigned
      only when every input is unsigned;
    * at least one float: a float of the widest float bit size seen,
      except that ``float16`` is widened to ``float32`` when an integer
      of 16 bits or more is also present.

    A collection holding exactly one type reduces to that type, whatever
    it is.

    Parameters
    ----------
    t : iterable of Type
      Types to reduce to a single type. The collection is not modified.
    method : str
      Reduction method. The value is accepted for forward compatibility
      but is not inspected; the behaviour described above is always used.

    Returns
    -------
    t : Type
      The reduced type.

    Raises
    ------
    TypeParseError
      When the collection is empty, or holds more than one type and at
      least one of them is not a plain integer or float type.
    """
    # Work on a private list so the caller's collection is never mutated
    # and any iterable (set, list, tuple, generator) is accepted.
    types = list(t)

    has_int = has_float = has_other = False
    all_unsigned = True
    bit_int, bit_float = 8, 8

    for item_type in types:
        name = item_type.tostring()

        if 'int' in name:
            family = 'int'
        elif 'float' in name:
            family = 'float'
        else:
            family = None

        bits = None
        if family is not None:
            try:
                bits = int(name.split(family)[1])
            except ValueError:
                # The name merely contains "int"/"float" (e.g. a composite
                # type); treat it as non-numeric instead of crashing.
                bits = None

        if bits is None:
            has_other = True
        elif family == 'int':
            has_int = True
            if 'uint' not in name:
                all_unsigned = False
            bit_int = max(bit_int, bits)
        else:
            has_float = True
            bit_float = max(bit_float, bits)

    if (has_int or has_float) and not has_other and len(types) != 1:
        if has_float:
            # float16 cannot hold the range of 16-bit or wider integers.
            if bit_float == 16 and bit_int >= 16:
                bit_float = 32
            com_type = f'float{bit_float}'
        else:
            com_type = f'int{bit_int}'
            if all_unsigned:
                com_type = 'u' + com_type
        return cls.fromstring(com_type)

    if len(types) == 1:
        return types[0]
    raise TypeParseError(f"Failed to cast {types!r} to a single type")

def _normalize(self):
"""Return new Type instance with atomic types normalized.
"""
Expand Down

0 comments on commit 8da592e

Please sign in to comment.