forked from ml-explore/mlx
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdevice.cpp
More file actions
102 lines (89 loc) · 2.93 KB
/
device.cpp
File metadata and controls
102 lines (89 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Copyright © 2023-2025 Apple Inc.
#include <optional>
#include <sstream>
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include "mlx/device.h"
#include "mlx/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
// Registers the device bindings on module `m`: the `Device` class, the
// `DeviceType` enum (with `cpu` / `gpu` also exported at module level), the
// implicit DeviceType -> Device conversion, and the module-level functions
// `default_device`, `set_default_device`, `is_available`, `device_count` and
// `device_info`.
void init_device(nb::module_& m) {
  // The `Device` type object is created first; its methods are attached
  // further down, after `DeviceType` has been registered.
  auto device_class = nb::class_<mx::Device>(
      m, "Device", R"pbdoc(A device to run operations on.)pbdoc");

  nb::enum_<mx::Device::DeviceType>(m, "DeviceType")
      .value("cpu", mx::Device::DeviceType::cpu)
      .value("gpu", mx::Device::DeviceType::gpu)
      .export_values()
      // A DeviceType can be compared with a Device or another DeviceType.
      // Any other object compares unequal (returns False) instead of raising.
      // A DeviceType `other` reaches the Device cast through the implicit
      // DeviceType -> Device conversion registered below.
      .def(
          "__eq__",
          [](const mx::Device::DeviceType& d, const nb::object& other) {
            if (!nb::isinstance<mx::Device>(other) &&
                !nb::isinstance<mx::Device::DeviceType>(other)) {
              return false;
            }
            return d == nb::cast<mx::Device>(other);
          });

  device_class
      .def(nb::init<mx::Device::DeviceType, int>(), "type"_a, "index"_a = 0)
      .def_ro("type", &mx::Device::type)
      // The constructor accepts `index`, and `device_count` can report more
      // than one device per type, so expose the index read-only next to
      // `type`. Otherwise Python callers have no attribute to read it from.
      .def_ro("index", &mx::Device::index)
      .def(
          "__repr__",
          [](const mx::Device& d) {
            std::ostringstream os;
            os << d;
            return os.str();
          })
      // Same contract as DeviceType.__eq__: Device or DeviceType operands are
      // converted to Device and compared; anything else is simply unequal.
      .def("__eq__", [](const mx::Device& d, const nb::object& other) {
        if (!nb::isinstance<mx::Device>(other) &&
            !nb::isinstance<mx::Device::DeviceType>(other)) {
          return false;
        }
        return d == nb::cast<mx::Device>(other);
      });

  // Lets Python code pass a bare DeviceType (e.g. `mx.gpu`) wherever a Device
  // is expected, including the `nb::cast<mx::Device>` calls above.
  nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();

  m.def(
      "default_device",
      &mx::default_device,
      R"pbdoc(Get the default device.)pbdoc");
  m.def(
      "set_default_device",
      &mx::set_default_device,
      "device"_a,
      R"pbdoc(Set the default device.)pbdoc");
  m.def(
      "is_available",
      &mx::is_available,
      "device"_a,
      R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
  m.def(
      "device_count",
      &mx::device_count,
      "device_type"_a,
      R"pbdoc(
Get the number of available devices for the given device type.
Args:
device_type (DeviceType): The type of device to query (cpu or gpu).
Returns:
int: Number of devices.
)pbdoc");
  m.def(
      "device_info",
      // `d` is optional on the Python side; None (the default) queries the
      // current default device.
      [](std::optional<mx::Device> d) {
        return mx::device_info(d.value_or(mx::default_device()));
      },
      "d"_a = nb::none(),
      R"pbdoc(
Get information about a device.
Returns a dictionary with device properties. Available keys depend
on the backend and device type. Common keys include ``device_name``,
``architecture``, and ``total_memory`` (or ``memory_size``).
Args:
d (Device): The device to query (defaults to the default device).
Returns:
dict: Device information.
)pbdoc");
}