diff --git a/rbc/heavydb/remoteheavydb.py b/rbc/heavydb/remoteheavydb.py index d78fe6ce..6bc4b8ac 100644 --- a/rbc/heavydb/remoteheavydb.py +++ b/rbc/heavydb/remoteheavydb.py @@ -289,6 +289,7 @@ def type_to_type_name(typ: typesystem.Type): """ styp = typ.tostring(use_annotation=False, use_name=False) type_name = dict( + bool8='BOOLEAN', int8='TINYINT', int16='SMALLINT', int32='INT', diff --git a/rbc/remotejit.py b/rbc/remotejit.py index fe1e0061..f0b56b7b 100644 --- a/rbc/remotejit.py +++ b/rbc/remotejit.py @@ -440,9 +440,9 @@ def __call__(self, *arguments, device=UNSPECIFIED, hold=UNSPECIFIED): if device is not UNSPECIFIED and device != device_: continue with target_info: - atypes = self.remotejit.get_types(*arguments) - for caller_id, caller in enumerate(self.callers): - with Type.alias(**self.remotejit.typesystem_aliases): + with Type.alias(**self.remotejit.typesystem_aliases): + atypes = self.remotejit.get_types(*arguments) + for caller_id, caller in enumerate(self.callers): ftype, penalty = caller.signature.best_match(caller.func, atypes) if ftype is None: continue diff --git a/rbc/tests/heavydb/test_caller.py b/rbc/tests/heavydb/test_caller.py index e5419f05..6678aeeb 100644 --- a/rbc/tests/heavydb/test_caller.py +++ b/rbc/tests/heavydb/test_caller.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from numpy.testing import assert_array_equal from rbc.externals.heavydb import set_output_row_size from rbc.heavydb import TextEncodingNone from rbc.tests import heavydb_fixture, assert_equal @@ -241,3 +242,29 @@ def incr_ol(x, dx): # noqa: F811 assert incr_ol(1).execute() == 2 assert incr_ol(1, 2).execute() == 3 + + +def test_remote_call_bool(heavydb): + # RBC issue 575 + if heavydb.version[:2] < (7, 0): + pytest.skip('Test requires HeavyDB 7.0 or newer') + + @heavydb('bool(bool)') + def inverse_bool(b): + return False if b else True + + assert inverse_bool(True).execute() == False # noqa: E712 + assert inverse_bool(False).execute() == True # noqa: E712 + + from rbc.stdlib import array_api + + @heavydb('bool[](bool[])') + def inv_bool_arr(arr): + sz = len(arr) + r = array_api.zeros_like(arr) + for i in range(sz): + r[i] = False if arr[i] else True + return r + + assert_array_equal(inv_bool_arr([False, True]).execute(), [True, False]) + assert_array_equal(inv_bool_arr([True, True]).execute(), [False, False]) diff --git a/rbc/typesystem.py b/rbc/typesystem.py index 64e948e9..650dba5e 100644 --- a/rbc/typesystem.py +++ b/rbc/typesystem.py @@ -255,7 +255,7 @@ def topython(self): # python_imap values must be processed with Type.fromstring _python_imap = {int: 'int64', float: 'float64', complex: 'complex128', - str: 'string', bytes: 'char*'} + str: 'string', bytes: 'char*', bool: 'bool'} # Data for the mangling algorithm, see mangle/demangle methods. #