diff --git a/rbc/irtools.py b/rbc/irtools.py index 99ca5cce..471cd42a 100644 --- a/rbc/irtools.py +++ b/rbc/irtools.py @@ -5,7 +5,7 @@ import numba as nb from llvmlite import ir import llvmlite.binding as llvm -from .utils import is_localhost +from .utils import is_localhost, triple_matches def initialize_llvm(): @@ -61,16 +61,13 @@ def compile_to_LLVM(functions_and_signatures, target, server=None, LLVM module instance. To get the IR string, use `str(module)`. """ - cpu_target = llvm.get_process_triple() if server is None or is_localhost(server.host): - if target == 'host' or target == cpu_target: - # FYI, there is also get_process_triple() - # triple = llvm.get_default_triple() + if triple_matches(target, 'host'): target_desc = nb.targets.registry.cpu_target typing_context = target_desc.typing_context target_context = target_desc.target_context use_host_target = False - elif target == 'cuda' or target == 'nvptx64-nvidia-cuda': + elif triple_matches(target, 'cuda'): if use_host_target: triple = 'nvptx64-nvidia-cuda' data_layout = nb.cuda.cudadrv.nvvm.data_layout[64] @@ -81,7 +78,7 @@ def compile_to_LLVM(functions_and_signatures, target, server=None, target_desc = nb.cuda.descriptor.CUDATargetDesc typing_context = target_desc.typingctx target_context = target_desc.targetctx - elif target == 'cuda32' or target == 'nvptx-nvidia-cuda': + elif triple_matches(target, 'cuda32'): if use_host_target: triple = 'nvptx-nvidia-cuda' data_layout = nb.cuda.cudadrv.nvvm.data_layout[32] @@ -93,6 +90,7 @@ def compile_to_LLVM(functions_and_signatures, target, server=None, typing_context = target_desc.typingctx target_context = target_desc.targetctx else: + cpu_target = llvm.get_process_triple() raise NotImplementedError(repr((target, cpu_target))) else: raise NotImplementedError(repr((target, server))) diff --git a/rbc/tests/test_utils.py b/rbc/tests/test_utils.py index e71dd03e..4f5331d3 100644 --- a/rbc/tests/test_utils.py +++ b/rbc/tests/test_utils.py @@ -1,5 +1,13 @@ -from 
rbc.utils import is_localhost, get_local_ip
+from rbc.utils import is_localhost, get_local_ip, triple_matches
 
 
 def test_is_localhost():
     assert is_localhost(get_local_ip())
+
+
+def test_triple_matches():
+    assert triple_matches('cuda', 'nvptx64-nvidia-cuda')
+    assert triple_matches('nvptx64-nvidia-cuda', 'cuda')
+    assert triple_matches('cuda32', 'nvptx-nvidia-cuda')
+    assert triple_matches('nvptx-nvidia-cuda', 'cuda32')
+    assert triple_matches('x86_64-pc-linux-gnu', 'x86_64-unknown-linux-gnu')
diff --git a/rbc/utils.py b/rbc/utils.py
index cde6fd4e..5ce9575a 100644
--- a/rbc/utils.py
+++ b/rbc/utils.py
@@ -4,6 +4,7 @@
 import netifaces
 import uuid
 import ctypes
+import llvmlite.binding as llvm
 
 
 def get_local_ip():
@@ -81,3 +82,46 @@ def get_datamodel():
         (16, 64, 64, 64, 64): 'ILP64',  # HAL
         (64, 64, 64, 64, 64): 'SILP64',  # UNICOS
     }[short_sizeof, int_sizeof, long_sizeof, ptr_sizeof, longlong_sizeof]
+
+
+def triple_split(triple):
+    """Split target triple into parts.
+
+    Return a 4-tuple ``(arch, vendor, os, env)``. Components missing
+    from `triple` are returned as empty strings, so a short or
+    unrecognized name never fails on unpacking.
+    """
+    parts = triple.split('-', 3)
+    # Pad to four items: 'nvptx64-nvidia-cuda' has no env part, and a
+    # bare name such as 'foo' has only one component.
+    parts.extend([''] * (4 - len(parts)))
+    arch, vendor, os, env = parts
+    return arch, vendor, os, env
+
+
+def triple_matches(triple, other):
+    """Check if target triples match.
+
+    The names 'cuda', 'cuda32' and 'host' are accepted on either side
+    as aliases of 'nvptx64-nvidia-cuda', 'nvptx-nvidia-cuda' and the
+    LLVM process triple, respectively. Names that are not triples
+    simply do not match; they do not raise.
+    """
+    if triple == other:
+        return True
+    if triple == 'cuda':
+        return triple_matches('nvptx64-nvidia-cuda', other)
+    if triple == 'cuda32':
+        return triple_matches('nvptx-nvidia-cuda', other)
+    if triple == 'host':
+        return triple_matches(llvm.get_process_triple(), other)
+    if other in ['cuda', 'cuda32', 'host']:
+        # Swap so that the alias on the right-hand side gets expanded.
+        return triple_matches(other, triple)
+    arch1, vendor1, os1, env1 = triple_split(triple)
+    arch2, vendor2, os2, env2 = triple_split(other)
+    if os1 == os2 == 'linux':
+        # Vendor is ignored on Linux so that 'x86_64-pc-linux-gnu'
+        # matches 'x86_64-unknown-linux-gnu'.
+        return (arch1, env1) == (arch2, env2)
+    return (arch1, vendor1, os1, env1) == (arch2, vendor2, os2, env2)