forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathc_api_util.py
More file actions
150 lines (116 loc) · 4.15 KB
/
c_api_util.py
File metadata and controls
150 lines (116 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for using the TensorFlow C API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
class ScopedTFStatus(object):
  """Wrapper around TF_Status that handles deletion.

  A TF_Status is created with `c_api.TF_NewStatus()` on construction and
  stored in `self.status`. It is released with `c_api.TF_DeleteStatus()` when
  this wrapper object is finalized.
  """

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. At that point
    # the module-level name `c_api` itself may have been reset to None (not
    # only the attributes on it), so both are checked before the deleter is
    # called; otherwise the attribute lookup would raise inside the finalizer.
    if c_api is not None and c_api.TF_DeleteStatus is not None:
      c_api.TF_DeleteStatus(self.status)
class ScopedTFGraph(object):
  """Wrapper around TF_Graph that handles deletion.

  A TF_Graph is created with `c_api.TF_NewGraph()` on construction and stored
  in `self.graph`. It is released with `c_api.TF_DeleteGraph()` when this
  wrapper object is finalized.
  """

  def __init__(self):
    self.graph = c_api.TF_NewGraph()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. At that point
    # the module-level name `c_api` itself may have been reset to None (not
    # only the attributes on it), so both are checked before the deleter is
    # called; otherwise the attribute lookup would raise inside the finalizer.
    if c_api is not None and c_api.TF_DeleteGraph is not None:
      c_api.TF_DeleteGraph(self.graph)
class ScopedTFImportGraphDefOptions(object):
  """Wrapper around TF_ImportGraphDefOptions that handles deletion.

  A TF_ImportGraphDefOptions is created with
  `c_api.TF_NewImportGraphDefOptions()` on construction and stored in
  `self.options`. It is released with `c_api.TF_DeleteImportGraphDefOptions()`
  when this wrapper object is finalized.
  """

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. At that point
    # the module-level name `c_api` itself may have been reset to None (not
    # only the attributes on it), so both are checked before the deleter is
    # called; otherwise the attribute lookup would raise inside the finalizer.
    if (c_api is not None and
        c_api.TF_DeleteImportGraphDefOptions is not None):
      c_api.TF_DeleteImportGraphDefOptions(self.options)
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that owns a TF_Buffer for the duration of a `with` block.

  The buffer is created on entry and always deleted on exit, including when
  the body of the `with` block raises.

  Example usage:

    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. When given and
      non-empty, it is converted with `compat.as_bytes` and the yielded buffer
      is created from it; otherwise a new empty buffer is created.

  Yields:
    Created TF_Buffer
  """
  # A falsy `data` (None or empty) takes the empty-buffer path.
  handle = (c_api.TF_NewBufferFromString(compat.as_bytes(data)) if data
            else c_api.TF_NewBuffer())
  try:
    yield handle
  finally:
    # Release the buffer whether or not the caller's block succeeded.
    c_api.TF_DeleteBuffer(handle)
def tf_output(c_op, index):
  """Builds a wrapped TF_Output pointing at one output of an operation.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output whose `oper` is `c_op` and whose `index` is `index`.
  """
  output = c_api.TF_Output()
  # Targets are assigned left to right: `oper` first, then `index`.
  output.oper, output.index = c_op, index
  return output
def tf_operations(graph):
  """Generator that walks `graph` and yields each of its TF_Operations.

  Iteration is driven by `c_api.TF_GraphNextOperation`, which is called with a
  position that starts at 0 and is replaced by the position it returns. The
  walk stops as soon as it returns None for the operation.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  cursor = 0
  while True:
    c_op, cursor = c_api.TF_GraphNextOperation(graph._c_graph, cursor)
    if c_op is None:
      break
    yield c_op
  # pylint: enable=protected-access
def new_tf_operations(graph):
  """Generator that yields newly-added TF_Operations in `graph`.

  An operation counts as newly added when looking it up with
  `graph._get_operation_by_tf_operation` raises KeyError, i.e. `graph` has no
  associated Operation for it. This is useful for processing nodes added by
  the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently

  def _lacks_python_op(c_op):
    # True when `graph` has no Operation registered for `c_op`.
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      return True
    return False

  for candidate in tf_operations(graph):
    if _lacks_python_op(candidate):
      yield candidate