forked from arrayfire-community/arrayfire_python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path core.py
More file actions
27 lines (24 loc) · 692 Bytes
/
core.py
File metadata and controls
27 lines (24 loc) · 692 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from . import array
from .array import *
from . import helper
from .helper import *
def dot(A, B):
    """Return the matrix product of ``A`` and ``B`` computed by ``libaf.matmul``.

    A result wrapper is built with shape ``(A.shape[0], B.shape[1])`` and
    ``A``'s dtype, then its ``ptr`` is replaced by the handle returned from
    ``libaf.matmul(A.ptr, B.ptr)``.

    Both operands are indexed as 2-D (``A.shape[0]``, ``B.shape[1]``);
    NOTE(review): no inner-dimension check is done here -- presumably the
    native call reports a mismatch by returning a null handle; confirm.

    Raises:
        Exception: if ``libaf.matmul`` returns a null (0) handle.
    """
    rows, cols = A.shape[0], B.shape[1]
    product = ndarray(shape=(rows, cols), dtype=A.dtype)
    product.ptr = libaf.matmul(A.ptr, B.ptr)
    # A zero handle is treated as failure of the native call.
    if product.ptr == 0:
        raise Exception("Failure in dot")
    return product
def sum(A, dim=None):
    """Sum the elements of ``A``, either over everything or along one dimension.

    NOTE: this name shadows the builtin ``sum`` inside this module; it is kept
    for API compatibility.

    Args:
        A: array wrapper exposing ``ptr``, ``shape`` and ``dtype``.
        dim: index into ``A.shape`` of the dimension to reduce. ``None``
            (the default) reduces over all elements.

    Returns:
        If ``dim`` is None, the scalar returned by ``libaf.sumAll`` (declared
        with ``c_double`` as its restype, so ctypes yields a Python float).
        Otherwise an ``ndarray`` whose shape is ``A.shape`` with entry ``dim``
        set to 1, post-processed by the ``strip`` helper.

    Raises:
        Exception: if ``libaf.sum`` returns a null (0) handle.
    """
    if dim is None:
        sum_all = libaf.sumAll
        # ctypes assumes an int return by default; declare double so the
        # scalar result is not misread. Call the configured alias directly.
        sum_all.restype = c_double
        return sum_all(A.ptr)

    # NOTE(review): a negative ``dim`` indexes ``lst`` from the end (Python
    # semantics) but is passed unchanged to ``libaf.sum`` -- confirm callers
    # only pass non-negative dims.
    lst = list(A.shape)
    lst[dim] = 1
    shape = strip(lst)
    res = ndarray(shape, dtype=A.dtype)
    res.ptr = libaf.sum(A.ptr, dim)
    # A zero handle is treated as failure of the native reduction.
    if res.ptr == 0:
        raise Exception("Failure in sum")
    return res