Skip to content

Commit

Permalink
Merge pull request #7 from xnd-project/pearu/target_match
Browse files Browse the repository at this point in the history
Use triple_matches for choosing target
  • Loading branch information
pearu authored Oct 18, 2019
2 parents 468011e + 3407d7f commit 744ad50
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 8 deletions.
12 changes: 5 additions & 7 deletions rbc/irtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numba as nb
from llvmlite import ir
import llvmlite.binding as llvm
from .utils import is_localhost
from .utils import is_localhost, triple_matches


def initialize_llvm():
Expand Down Expand Up @@ -61,16 +61,13 @@ def compile_to_LLVM(functions_and_signatures, target, server=None,
LLVM module instance. To get the IR string, use `str(module)`.
"""
cpu_target = llvm.get_process_triple()
if server is None or is_localhost(server.host):
if target == 'host' or target == cpu_target:
# FYI, there is also get_process_triple()
# triple = llvm.get_default_triple()
if triple_matches(target, 'host'):
target_desc = nb.targets.registry.cpu_target
typing_context = target_desc.typing_context
target_context = target_desc.target_context
use_host_target = False
elif target == 'cuda' or target == 'nvptx64-nvidia-cuda':
elif triple_matches(target, 'cuda'):
if use_host_target:
triple = 'nvptx64-nvidia-cuda'
data_layout = nb.cuda.cudadrv.nvvm.data_layout[64]
Expand All @@ -81,7 +78,7 @@ def compile_to_LLVM(functions_and_signatures, target, server=None,
target_desc = nb.cuda.descriptor.CUDATargetDesc
typing_context = target_desc.typingctx
target_context = target_desc.targetctx
elif target == 'cuda32' or target == 'nvptx-nvidia-cuda':
elif triple_matches(target, 'cuda32'):
if use_host_target:
triple = 'nvptx-nvidia-cuda'
data_layout = nb.cuda.cudadrv.nvvm.data_layout[32]
Expand All @@ -93,6 +90,7 @@ def compile_to_LLVM(functions_and_signatures, target, server=None,
typing_context = target_desc.typingctx
target_context = target_desc.targetctx
else:
cpu_target = llvm.get_process_triple()
raise NotImplementedError(repr((target, cpu_target)))
else:
raise NotImplementedError(repr((target, server)))
Expand Down
10 changes: 9 additions & 1 deletion rbc/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from rbc.utils import is_localhost, get_local_ip
from rbc.utils import is_localhost, get_local_ip, triple_matches


def test_is_localhost():
    """The address reported for this machine must be recognised as local."""
    local_address = get_local_ip()
    assert is_localhost(local_address)


def test_triple_matches():
    """Check alias handling in both argument positions and Linux vendors."""
    # Each pair is expected to match; the order mirrors the aliases being
    # accepted on either side of the comparison.
    matching_pairs = (
        ('cuda', 'nvptx64-nvidia-cuda'),
        ('nvptx64-nvidia-cuda', 'cuda'),
        ('cuda32', 'nvptx-nvidia-cuda'),
        ('nvptx-nvidia-cuda', 'cuda32'),
        ('x86_64-pc-linux-gnu', 'x86_64-unknown-linux-gnu'),
    )
    for left, right in matching_pairs:
        assert triple_matches(left, right)
32 changes: 32 additions & 0 deletions rbc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import netifaces
import uuid
import ctypes
import llvmlite.binding as llvm


def get_local_ip():
Expand Down Expand Up @@ -81,3 +82,34 @@ def get_datamodel():
(16, 64, 64, 64, 64): 'ILP64', # HAL
(64, 64, 64, 64, 64): 'SILP64', # UNICOS
}[short_sizeof, int_sizeof, long_sizeof, ptr_sizeof, longlong_sizeof]


def triple_split(triple):
    """Split a target triple into its four components.

    A triple has the form ``arch-vendor-os[-env]``. Components that are
    absent from the input are returned as empty strings, so the result
    always has exactly four items and short inputs (for example
    ``'x86_64-linux'`` or a bare name) no longer raise ``ValueError``.

    Parameters
    ----------
    triple : str
      Dash-separated target triple.

    Returns
    -------
    parts : tuple
      ``(arch, vendor, os, env)`` as strings. Any text following the
      third dash, including further dashes, is kept in ``env``.
    """
    # At most three splits: everything after the third dash stays in env.
    parts = triple.split('-', 3)
    # Pad the missing trailing components instead of failing on unpacking.
    parts.extend([''] * (4 - len(parts)))
    # `os_name` rather than `os` to avoid shadowing the standard module name.
    arch, vendor, os_name, env = parts
    return arch, vendor, os_name, env


def triple_matches(triple, other):
    """Tell whether two target triples designate the same target.

    Either argument may be one of the shorthand names ``'cuda'``
    (``nvptx64-nvidia-cuda``), ``'cuda32'`` (``nvptx-nvidia-cuda``) or
    ``'host'`` (the triple reported by ``llvm.get_process_triple()``).
    When both triples name Linux as the OS, the vendor field is ignored.

    Parameters
    ----------
    triple, other : str
      Target triples or shorthand names.

    Returns
    -------
    matches : bool
    """
    if triple == other:
        return True

    def expand(name):
        # Replace a shorthand by its full triple. 'host' is resolved
        # only on demand, so LLVM is queried just when it is needed.
        if name == 'cuda':
            return 'nvptx64-nvidia-cuda'
        if name == 'cuda32':
            return 'nvptx-nvidia-cuda'
        if name == 'host':
            return llvm.get_process_triple()
        return name

    lhs = expand(triple)
    rhs = expand(other)
    if lhs == rhs:
        return True

    lhs_arch, lhs_vendor, lhs_os, lhs_env = triple_split(lhs)
    rhs_arch, rhs_vendor, rhs_os, rhs_env = triple_split(rhs)
    # Architecture and environment must agree in every case.
    if (lhs_arch, lhs_env) != (rhs_arch, rhs_env):
        return False
    # On Linux the vendor field is not significant.
    if lhs_os == rhs_os == 'linux':
        return True
    return (lhs_vendor, lhs_os) == (rhs_vendor, rhs_os)

0 comments on commit 744ad50

Please sign in to comment.