From 8da592ec72a5c9b2258a04cfb74384463bad3e5d Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 14 Jul 2023 17:51:15 -0300 Subject: [PATCH] ENH: add support for list to execute (Take 2) (#573) * ENH: add support for list to execute * Update rbc/omniscidb.py Co-authored-by: Pearu Peterson * Add fromtypes to reduce a list of types * Fix converting list to SQL array literal. Add test_direct_call_array. * Refactor fromtypes to reducetypes and fix int16/float16 * Parametrize test with dtype * Fix error with ints * Unregister in tests * Disable fail-fast * Change function name for array tests * REVERT THIS COMMIT: Disable vectorization in tests * Fix mean from array_api * Rename functions to less than 22 chars * Rename OmnisciArrayType -> HeavyDBArrayType * Test requires most recent version of HeavyDB * add pytest.skip * OmnisciArrayType -> HeavyDBArrayType * add remote list evaluation test * skip test when HeavyDB version < 7.0 * add test for ndarray and TextEncodingNone * Revert "add test for ndarray and TextEncodingNone" This reverts commit 857db58f642c5957b270d69d55bbec8ed358fe55. 
--------- Co-authored-by: Pamphile Roy Co-authored-by: Pearu Peterson --- .github/workflows/rbc_test.yml | 5 +++ rbc/heavydb/remoteheavydb.py | 13 +++++++- rbc/tests/heavydb/test_caller.py | 18 +++++++++++ rbc/tests/heavydb/test_heavydb.py | 19 ++++++++++- rbc/tests/test_typesystem.py | 27 +++++++++++++++- rbc/typesystem.py | 54 +++++++++++++++++++++++++++++++ 6 files changed, 133 insertions(+), 3 deletions(-) diff --git a/.github/workflows/rbc_test.yml b/.github/workflows/rbc_test.yml index 3b8a9e7f1..65337e5dc 100644 --- a/.github/workflows/rbc_test.yml +++ b/.github/workflows/rbc_test.yml @@ -7,6 +7,11 @@ on: branches: - main +env: + NUMBA_LOOP_VECTORIZE: 0 + NUMBA_SLP_VECTORIZE: 0 + + # kill any previous running job on a new commit concurrency: group: build-and-test-rbc-${{ github.head_ref }} diff --git a/rbc/heavydb/remoteheavydb.py b/rbc/heavydb/remoteheavydb.py index 5c3d69c6b..d78fe6cec 100644 --- a/rbc/heavydb/remoteheavydb.py +++ b/rbc/heavydb/remoteheavydb.py @@ -298,6 +298,7 @@ def type_to_type_name(typ: typesystem.Type): ).get(styp) if type_name is not None: return type_name + raise NotImplementedError(f'converting `{styp}` to DatumType not supported') @@ -1423,7 +1424,13 @@ def get_types(self, *values): """ types = [] for value in values: - if isinstance(value, RemoteCallCapsule): + + if isinstance(value, (list, numpy.ndarray)): + items_types = set(map(typesystem.Type.fromvalue, value)) + com_type = typesystem.Type.reducetypes(items_types) + array_type = HeavyDBArrayType((com_type,)) + types.append(array_type) + elif isinstance(value, RemoteCallCapsule): typ = value.__typesystem_type__ if typ.is_struct and typ._params.get('struct_is_tuple'): types.extend(typ) @@ -1558,6 +1565,10 @@ def remote_call(self, func, ftype: typesystem.Type, arguments: tuple, hold=False elif isinstance(a, str): a = repr(a) args.append(f'{a}') + elif isinstance(atype, HeavyDBArrayType): + element_type_name = type_to_type_name(atype.element_type) + astr = ", ".join([f'CAST({a_} 
AS {element_type_name})' for a_ in a]) + args.append(f'ARRAY[{astr}]') else: args.append(f'CAST({a} AS {type_to_type_name(atype)})') args = ', '.join(args) diff --git a/rbc/tests/heavydb/test_caller.py b/rbc/tests/heavydb/test_caller.py index 814624b5e..e5419f05d 100644 --- a/rbc/tests/heavydb/test_caller.py +++ b/rbc/tests/heavydb/test_caller.py @@ -47,6 +47,10 @@ def myupper(s): r[i] = c return r + @heavydb('int64(T[])', T=['int64', 'double'], devices=['CPU']) + def mylength(lst): + return len(lst) + def test_udf_string_repr(heavydb): myincr = heavydb.get_caller('myincr') @@ -119,6 +123,20 @@ def test_remote_TextEncodingNone_evaluation(heavydb): assert str(myupper(b"abc").execute()) == 'ABC' +def test_remote_list_evaluation(heavydb): + if heavydb.version[:2] < (7, 0): + pytest.skip('Test requires HeavyDB 7.0 or newer') + + mylength = heavydb.get_caller('mylength') + assert str(mylength) == "mylength['int64(T[]), T=int64|double, device=CPU']" + assert str(mylength([1, 2])) == \ + "SELECT mylength(ARRAY[CAST(1 AS BIGINT), CAST(2 AS BIGINT)])" + assert mylength([1, 2]).execute() == 2 + assert str(mylength([1.0, 2])) == \ + "SELECT mylength(ARRAY[CAST(1.0 AS DOUBLE), CAST(2 AS DOUBLE)])" + assert mylength([1.0, 2]).execute() == 2 + + def test_remote_composite_udf_evaluation(heavydb): myincr = heavydb.get_caller('myincr') diff --git a/rbc/tests/heavydb/test_heavydb.py b/rbc/tests/heavydb/test_heavydb.py index eb48dfcad..8f709cebb 100644 --- a/rbc/tests/heavydb/test_heavydb.py +++ b/rbc/tests/heavydb/test_heavydb.py @@ -6,6 +6,7 @@ from rbc.errors import UnsupportedError, HeavyDBServerError from rbc.tests import heavydb_fixture, assert_equal from rbc.typesystem import Type +from rbc.stdlib import array_api rbc_heavydb = pytest.importorskip('rbc.heavydb') available_version, reason = rbc_heavydb.is_available() @@ -67,7 +68,7 @@ def heavydb(): yield o -def test_direct_call(heavydb): +def test_direct_call_scalar(heavydb): heavydb.reset() @heavydb('double(double)') @@ -77,6 
+78,22 @@ def farhenheit2celcius(f): assert_equal(farhenheit2celcius(40).execute(), np.float32(40 / 9)) +@pytest.mark.parametrize('dtype', ('float32', 'float64', 'int32', 'int64')) +def test_direct_call_array(heavydb, dtype): + if heavydb.version[:2] < (7, 0): + pytest.skip('Test requires HeavyDB 7.0 or newer') + + heavydb.unregister() + + @heavydb('T(T[])', T=['float32', 'float64', 'int64', 'int32'], devices=['cpu']) + def func(f): + return (array_api.mean(f)-32) * 5 / 9 + + ref_value = np.dtype(dtype).type(40 / 9) + inp_array = np.array([30, 50], dtype=dtype) + assert func(inp_array).execute() == pytest.approx(ref_value) + + def test_local_caller(heavydb): heavydb.reset() diff --git a/rbc/tests/test_typesystem.py b/rbc/tests/test_typesystem.py index b90f12724..0fea032ef 100644 --- a/rbc/tests/test_typesystem.py +++ b/rbc/tests/test_typesystem.py @@ -18,7 +18,7 @@ np = None import pytest -from rbc.typesystem import Type, get_signature +from rbc.typesystem import Type, get_signature, TypeParseError from rbc.utils import get_datamodel from rbc.targetinfo import TargetInfo @@ -82,6 +82,31 @@ def test_commasplit(): assert '^'.join(commasplit('a[:, :], b[:, :, :]')) == 'a[:, :]^b[:, :, :]' +def test_fromtypes(): + + def from_list(values): + return set(map(Type.fromvalue, values)) + + assert Type.reducetypes(from_list([1, 3., 6])) == Type('float64') + assert Type.reducetypes(from_list([1, 3, 6])) == Type('int64') + + types = from_list([np.int8(1), np.int8(3), np.int8(6)]) + assert Type.reducetypes(types) == Type('int8') + + types = from_list([np.int8(1), np.int32(3), np.int16(6)]) + assert Type.reducetypes(types) == Type('int32') + + types = from_list([np.uint8(1), np.uint32(3), np.int8(6)]) + assert Type.reducetypes(types) == Type('int32') + + types = from_list([np.int16(1), np.float16(3), np.int8(6)]) + assert Type.reducetypes(types) == Type('float32') + + msg = "Failed to cast" + with pytest.raises(TypeParseError, match=msg): + Type.reducetypes(from_list([1, 'a', 
6])) + + + def test_fromstring(target_info): assert Type.fromstring('void') == Type() diff --git a/rbc/typesystem.py b/rbc/typesystem.py index 3e6b7a61c..64e948e9a 100644 --- a/rbc/typesystem.py +++ b/rbc/typesystem.py @@ -1244,6 +1244,60 @@ def fromobject(cls, obj): return cls.fromcallable(obj) raise NotImplementedError(repr((type(obj)))) + @classmethod + def reducetypes(cls, t, *, method='highest'): + """Reduce a list of types into a single type. + + Parameters + ---------- + t : list(Type) or set(Type) + Collection of types to reduce to a single type. + method : {'highest'} + Reduction method. Can be: + + * 'highest': reduce to the highest common denominator. + + Returns + ------- + t : Type + New Type instance. + """ + has_uint, has_int, has_float, has_other = True, False, False, False + bit_int, bit_float = 8, 8 + + for item_type in t: + curr_type = item_type.tostring() + + if 'int' in curr_type: + has_int = True + if 'uint' not in curr_type: + has_uint = False + curr_bit = int(curr_type.split('int')[1]) + bit_int = max(bit_int, curr_bit) + elif 'float' in curr_type: + has_float = True + curr_bit = int(curr_type.split('float')[1]) + bit_float = max(bit_float, curr_bit) + else: + has_other = True + + if (has_int or has_float) and (not has_other) and len(t) != 1: + if not has_float: + com_type = f'int{bit_int}' + if has_uint: + com_type = 'u' + com_type + else: + if bit_float == 16 and bit_int >= 16: + bit_float = 32 + com_type = f'float{bit_float}' + + t = [Type.fromstring(com_type)] + + if len(t) == 1: + return t.pop() + else: + raise TypeParseError(f"Failed to cast {repr(t)!r} to a single type") + def _normalize(self): """Return new Type instance with atomic types normalized. """