Skip to content

Commit 809c12c

Browse files
Yao Zhangtensorflower-gardener
authored andcommitted
Disable constant folding for more tf.dbg tests, as they make assumptions on the
underlying graph structure. PiperOrigin-RevId: 169484497
1 parent 0baf077 commit 809c12c

2 files changed

Lines changed: 12 additions & 2 deletions

File tree

tensorflow/python/debug/cli/analyzer_cli_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747

4848
def no_rewrite_session_config():
4949
rewriter_config = rewriter_config_pb2.RewriterConfig(
50-
disable_model_pruning=True)
50+
disable_model_pruning=True,
51+
constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
5152
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
5253
return config_pb2.ConfigProto(graph_options=graph_options)
5354

tensorflow/python/debug/cli/profile_analyzer_cli_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from tensorflow.core.framework import step_stats_pb2
2424
from tensorflow.core.protobuf import config_pb2
25+
from tensorflow.core.protobuf import rewriter_config_pb2
2526
from tensorflow.python.client import session
2627
from tensorflow.python.debug.cli import debugger_cli_common
2728
from tensorflow.python.debug.cli import profile_analyzer_cli
@@ -35,6 +36,14 @@
3536
from tensorflow.python.util import tf_inspect
3637

3738

39+
def no_rewrite_session_config():
40+
rewriter_config = rewriter_config_pb2.RewriterConfig(
41+
disable_model_pruning=True,
42+
constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
43+
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
44+
return config_pb2.ConfigProto(graph_options=graph_options)
45+
46+
3847
def _line_number_above():
3948
return tf_inspect.stack()[1][2] - 1
4049

@@ -146,7 +155,7 @@ def testWithSession(self):
146155
options.trace_level = config_pb2.RunOptions.FULL_TRACE
147156
run_metadata = config_pb2.RunMetadata()
148157

149-
with session.Session() as sess:
158+
with session.Session(config=no_rewrite_session_config()) as sess:
150159
a = constant_op.constant([1, 2, 3])
151160
b = constant_op.constant([2, 2, 1])
152161
result = math_ops.add(a, b)

0 commit comments

Comments
 (0)