|
1 | 1 | import pytest |
2 | 2 |
|
3 | | -from cuda.core.experimental import Linker, LinkerOptions, Program |
| 3 | +from cuda.core.experimental import Linker, LinkerOptions, Program, _linker |
4 | 4 | from cuda.core.experimental._module import ObjectCode |
5 | 5 |
|
6 | 6 | ARCH = "sm_80" # use sm_80 for testing the oop nvJitLink wrapper |
7 | 7 |
|
8 | | - |
9 | | -device_function_a = """ |
10 | | -__device__ int B(); |
11 | | -__device__ int C(int a, int b); |
| 8 | +kernel_a = """ |
| 9 | +extern __device__ int B(); |
| 10 | +extern __device__ int C(int a, int b); |
12 | 11 | __global__ void A() { int result = C(B(), 1);} |
13 | 12 | """ |
14 | 13 | device_function_b = "__device__ int B() { return 0; }" |
15 | 14 | device_function_c = "__device__ int C(int a, int b) { return a + b; }" |
16 | 15 |
|
17 | | -culink_backend = False |
18 | | -try: |
19 | | - from cuda.bindings import nvjitlink # noqa F401 |
20 | | - from cuda.bindings._internal import nvjitlink as inner_nvjitlink |
21 | | -except ImportError: |
22 | | - # binding is not available |
23 | | - culink_backend = True |
24 | | -else: |
25 | | - if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0: |
26 | | - # binding is available, but nvJitLink is not installed |
27 | | - culink_backend = True |
| 16 | +culink_backend = _linker._decide_nvjitlink_or_driver() |
28 | 17 |
|
29 | 18 |
|
30 | 19 | @pytest.fixture(scope="function") |
31 | 20 | def compile_ptx_functions(init_cuda): |
32 | | - # Without rdc (relocatable device code) option, the generated ptx will not included any unreferenced |
| 21 | + # Without the -rdc (relocatable device code) option, the generated ptx will not include any unreferenced |
33 | 22 | # device functions, causing the link to fail |
| 23 | + object_code_a_ptx = Program(kernel_a, "c++").compile("ptx", options=("-rdc=true",)) |
34 | 24 | object_code_b_ptx = Program(device_function_b, "c++").compile("ptx", options=("-rdc=true",)) |
35 | 25 | object_code_c_ptx = Program(device_function_c, "c++").compile("ptx", options=("-rdc=true",)) |
36 | | - object_code_a_ptx = Program(device_function_a, "c++").compile("ptx", options=("-rdc=true",)) |
37 | 26 |
|
38 | 27 | return object_code_a_ptx, object_code_b_ptx, object_code_c_ptx |
39 | 28 |
|
40 | 29 |
|
41 | 30 | @pytest.fixture(scope="function") |
42 | 31 | def compile_ltoir_functions(init_cuda): |
| 32 | + object_code_a_ltoir = Program(kernel_a, "c++").compile("ltoir", options=("-dlto",)) |
43 | 33 | object_code_b_ltoir = Program(device_function_b, "c++").compile("ltoir", options=("-dlto",)) |
44 | 34 | object_code_c_ltoir = Program(device_function_c, "c++").compile("ltoir", options=("-dlto",)) |
45 | | - object_code_a_ltoir = Program(device_function_a, "c++").compile("ltoir", options=("-dlto",)) |
46 | 35 |
|
47 | 36 | return object_code_a_ltoir, object_code_b_ltoir, object_code_c_ltoir |
48 | 37 |
|
|
0 commit comments