Skip to content

Commit 677bd6d

Browse files
committed
reuse backend decision logic in tests + some nitpicks
1 parent 9b72cac commit 677bd6d

2 files changed

Lines changed: 32 additions & 31 deletions

File tree

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,43 @@
2020
_nvjitlink_input_types = None # populated if nvJitLink cannot be used
2121

2222

23-
def _lazy_init():
24-
global _inited
25-
if _inited:
23+
# Note: this function is reused in the tests
24+
def _decide_nvjitlink_or_driver():
25+
"""Returns True if falling back to the cuLink* driver APIs."""
26+
global _driver_ver, _driver, _nvjitlink
27+
if _driver or _nvjitlink:
2628
return
2729

28-
global _driver, _driver_input_types, _driver_ver, _nvjitlink, _nvjitlink_input_types
2930
_driver_ver = handle_return(cuda.cuDriverGetVersion())
3031
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
3132
try:
32-
from cuda.bindings import nvjitlink
33+
from cuda.bindings import nvjitlink as _nvjitlink
3334
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
3435
except ImportError:
3536
# binding is not available
36-
nvjitlink = None
37+
_nvjitlink = None
3738
else:
3839
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
3940
# binding is available, but nvJitLink is not installed
40-
nvjitlink = None
41-
elif _driver_ver > nvjitlink.version():
41+
_nvjitlink = None
42+
43+
if _nvjitlink is None:
44+
_driver = cuda
45+
return True
46+
else:
47+
return False
48+
49+
50+
def _lazy_init():
51+
global _inited, _nvjitlink_input_types, _driver_input_types
52+
if _inited:
53+
return
54+
55+
_decide_nvjitlink_or_driver()
56+
if _nvjitlink:
57+
if _driver_ver > _nvjitlink.version():
4258
# TODO: nvJitLink is not new enough, warn?
4359
pass
44-
if nvjitlink:
45-
_nvjitlink = nvjitlink
4660
_nvjitlink_input_types = {
4761
"ptx": _nvjitlink.InputType.PTX,
4862
"cubin": _nvjitlink.InputType.CUBIN,
@@ -51,8 +65,6 @@ def _lazy_init():
5165
"object": _nvjitlink.InputType.OBJECT,
5266
}
5367
else:
54-
from cuda import cuda as _driver
55-
5668
_driver_input_types = {
5769
"ptx": _driver.CUjitInputType.CU_JIT_INPUT_PTX,
5870
"cubin": _driver.CUjitInputType.CU_JIT_INPUT_CUBIN,

cuda_core/tests/test_linker.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,37 @@
11
import pytest
22

3-
from cuda.core.experimental import Linker, LinkerOptions, Program
3+
from cuda.core.experimental import Linker, LinkerOptions, Program, _linker
44
from cuda.core.experimental._module import ObjectCode
55

66
ARCH = "sm_80" # use sm_80 for testing the oop nvJitLink wrapper
77

8-
9-
device_function_a = """
10-
__device__ int B();
11-
__device__ int C(int a, int b);
8+
kernel_a = """
9+
extern __device__ int B();
10+
extern __device__ int C(int a, int b);
1211
__global__ void A() { int result = C(B(), 1);}
1312
"""
1413
device_function_b = "__device__ int B() { return 0; }"
1514
device_function_c = "__device__ int C(int a, int b) { return a + b; }"
1615

17-
culink_backend = False
18-
try:
19-
from cuda.bindings import nvjitlink # noqa F401
20-
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
21-
except ImportError:
22-
# binding is not available
23-
culink_backend = True
24-
else:
25-
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
26-
# binding is available, but nvJitLink is not installed
27-
culink_backend = True
16+
culink_backend = _linker._decide_nvjitlink_or_driver()
2817

2918

3019
@pytest.fixture(scope="function")
3120
def compile_ptx_functions(init_cuda):
32-
# Without rdc (relocatable device code) option, the generated ptx will not included any unreferenced
21+
# Without -rdc (relocatable device code) option, the generated ptx will not include any unreferenced
3322
# device functions, causing the link to fail
23+
object_code_a_ptx = Program(kernel_a, "c++").compile("ptx", options=("-rdc=true",))
3424
object_code_b_ptx = Program(device_function_b, "c++").compile("ptx", options=("-rdc=true",))
3525
object_code_c_ptx = Program(device_function_c, "c++").compile("ptx", options=("-rdc=true",))
36-
object_code_a_ptx = Program(device_function_a, "c++").compile("ptx", options=("-rdc=true",))
3726

3827
return object_code_a_ptx, object_code_b_ptx, object_code_c_ptx
3928

4029

4130
@pytest.fixture(scope="function")
4231
def compile_ltoir_functions(init_cuda):
32+
object_code_a_ltoir = Program(kernel_a, "c++").compile("ltoir", options=("-dlto",))
4333
object_code_b_ltoir = Program(device_function_b, "c++").compile("ltoir", options=("-dlto",))
4434
object_code_c_ltoir = Program(device_function_c, "c++").compile("ltoir", options=("-dlto",))
45-
object_code_a_ltoir = Program(device_function_a, "c++").compile("ltoir", options=("-dlto",))
4635

4736
return object_code_a_ltoir, object_code_b_ltoir, object_code_c_ltoir
4837

0 commit comments

Comments
 (0)