Skip to content

Commit 2e2c175

Browse files
isaprykin and tensorflower-gardener
authored and committed
Add a way to run ops using a step function to MonitoredSession.
With this method users have access to a raw Session while getting the benefit of recoverable behavior of MonitoredSession. PiperOrigin-RevId: 173334319
1 parent 171dc9f commit 2e2c175

7 files changed

Lines changed: 328 additions & 4 deletions

tensorflow/python/debug/wrappers/framework.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,10 @@ def _is_disabled_thread(self):
551551
return (self._thread_name_filter_pattern and
552552
not self._thread_name_filter_pattern.match(thread_name))
553553

554+
def run_step_fn(self, step_fn):
  """Invokes `step_fn` with a `StepContext` built from this wrapper.

  The context is constructed from the wrapped session (`self._sess`) and this
  wrapper's own `run` method, so `run_with_hooks()` calls made by `step_fn`
  go through `self.run`.

  Args:
    step_fn: A callable that accepts the `StepContext` as its only argument.

  Returns:
    Whatever `step_fn` returns.
  """
  step_context = monitored_session.MonitoredSession.StepContext(
      self._sess, self.run)
  return step_fn(step_context)
557+
554558
def partial_run_setup(self, fetches, feeds=None):
555559
"""Sets up the feeds and fetches for partial runs in the session."""
556560
raise NotImplementedError(
@@ -792,7 +796,7 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
792796

793797
def __init__(self, sess, watch_fn=None, thread_name_filter=None,
794798
pass_through_operrors=False):
795-
"""Constructor of DumpingDebugWrapperSession.
799+
"""Constructor of NonInteractiveDebugWrapperSession.
796800
797801
Args:
798802
sess: The TensorFlow `Session` object being wrapped.

tensorflow/python/training/monitored_session.py

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import six
2626

2727
from tensorflow.core.protobuf import config_pb2
28+
from tensorflow.python.estimator import util
2829
from tensorflow.python.framework import errors
2930
from tensorflow.python.framework import ops
3031
from tensorflow.python.ops import array_ops
@@ -493,6 +494,7 @@ def __init__(self, session_creator, hooks, should_recover,
493494
self._sess = _RecoverableSession(self._coordinated_creator)
494495
else:
495496
self._sess = self._coordinated_creator.create_session()
497+
self._stop_requested_in_step_fn = False
496498

497499
@property
498500
def graph(self):
@@ -520,10 +522,104 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
520522
options=options,
521523
run_metadata=run_metadata)
522524

525+
def run_step_fn(self, step_fn):
  """Run ops using a step function.

  Args:
    step_fn: A function or a method with a single argument of type
      `StepContext`.  The function may use methods of the argument to
      perform computations with access to a raw session.

      The returned value of the `step_fn` will be returned from `run_step_fn`,
      unless a stop is requested.  In that case, the next `should_stop` call
      will return True.

      Example usage:

      ```python
      with tf.Graph().as_default():
        c = tf.placeholder(tf.float32)
        v = tf.add(c, 4.0)
        w = tf.add(c, 0.5)

        def step_fn(step_context):
          a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
          if a <= 4.5:
            step_context.request_stop()
          return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

        with tf.MonitoredSession() as session:
          while not session.should_stop():
            a = session.run_step_fn(step_fn)
      ```

      Hooks interact with the `run_with_hooks()` call inside the `step_fn`
      as they do with a `MonitoredSession.run` call.

  Returns:
    Returns the returned value of `step_fn`.

  Raises:
    StopIteration: if `step_fn` has called `request_stop()`.  It may be
      caught by `with tf.MonitoredSession()` to close the session.
    ValueError: if `step_fn` doesn't have a single argument called
      `step_context`. It may also optionally have `self` for cases when it
      belongs to an object.
  """
  # Only two signatures are accepted: a plain function taking `step_context`,
  # or a method taking `self` and `step_context`.
  step_fn_arguments = util.fn_args(step_fn)
  if step_fn_arguments not in (('step_context',), ('self', 'step_context')):
    raise ValueError(
        '`step_fn` may either have one `step_context` argument, or'
        ' `self` and `step_context` arguments if it\'s an instance'
        ' method. Got {} instead.'.format(step_fn_arguments))

  try:
    return step_fn(_MonitoredSession.StepContext(self._tf_sess(), self.run))
  except StopIteration:
    # Record the stop request so that `should_stop()` reports it, then let
    # the exception propagate to the caller.
    self._stop_requested_in_step_fn = True
    raise
583+
class StepContext(object):
  """Handle given to a `step_fn` by `run_step_fn()` to drive a single step.

  A `step_fn` can issue `run()` calls that bypass hooks through the `session`
  property, issue a `run()` call that involves hooks through
  `run_with_hooks()`, and interrupt the computation flow through
  `request_stop()`.
  """

  def __init__(self, session, run_with_hooks_fn):
    """Creates the `step_context` argument for one `step_fn` invocation.

    Args:
      session: An instance of `tf.Session`.
      run_with_hooks_fn: A function for running fetches and hooks.
    """
    self._run_with_hooks_fn = run_with_hooks_fn
    self._session = session

  @property
  def session(self):
    """The session object supplied at construction time."""
    return self._session

  def run_with_hooks(self, *args, **kwargs):
    """Same as `MonitoredSession.run`. Accepts the same arguments."""
    run_fn = self._run_with_hooks_fn
    return run_fn(*args, **kwargs)

  def request_stop(self):
    """Exit the training loop by causing `should_stop()` to return `True`.

    Causes `step_fn` to exit by raising an exception.

    Raises:
      StopIteration
    """
    raise StopIteration('step_fn has requested the iterations to stop.')
619+
523620
def should_stop(self):
524-
if self._sess:
525-
return self._sess.should_stop()
526-
return True
621+
return (self._sess is None or self._sess.should_stop() or
622+
self._stop_requested_in_step_fn)
527623

528624
def close(self):
529625
self._close_internal()

tensorflow/python/training/monitored_session_test.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@
3333
from tensorflow.core.protobuf import debug_pb2
3434
from tensorflow.python.client import session as session_lib
3535
from tensorflow.python.framework import constant_op
36+
from tensorflow.python.framework import dtypes
3637
from tensorflow.python.framework import errors_impl
3738
from tensorflow.python.framework import ops
3839
from tensorflow.python.ops import array_ops
3940
from tensorflow.python.ops import control_flow_ops
41+
from tensorflow.python.ops import resource_variable_ops
4042
from tensorflow.python.ops import state_ops
4143
from tensorflow.python.ops import variables
4244
from tensorflow.python.platform import test
@@ -1449,6 +1451,170 @@ def test_with_statement_and_close(self):
14491451
with monitored_session.MonitoredSession() as session:
14501452
session.close()
14511453

1454+
def test_step_fn_example(self):
1455+
with ops.Graph().as_default():
1456+
c = array_ops.placeholder(dtypes.float32)
1457+
v = array_ops.identity(c)
1458+
1459+
def step_fn(step_context):
1460+
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
1461+
return value
1462+
1463+
with monitored_session.MonitoredSession() as session:
1464+
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
1465+
1466+
def test_step_function_stops(self):
1467+
with ops.Graph().as_default():
1468+
1469+
def step_fn(step_context):
1470+
step_context.request_stop()
1471+
1472+
with monitored_session.MonitoredSession() as session:
1473+
self.assertEqual(None, session.run_step_fn(step_fn))
1474+
self.assertTrue(session.should_stop())
1475+
1476+
def test_step_request_stop_without_a_with_block(self):
1477+
with ops.Graph().as_default():
1478+
1479+
def step_fn(step_context):
1480+
step_context.request_stop()
1481+
1482+
session = monitored_session.MonitoredSession()
1483+
try:
1484+
self.assertEqual(None, session.run_step_fn(step_fn))
1485+
except StopIteration:
1486+
pass
1487+
self.assertTrue(session.should_stop())
1488+
1489+
def test_step_request_stop_in_a_loop(self):
1490+
with ops.Graph().as_default():
1491+
def step_fn(step_context):
1492+
step_context.request_stop()
1493+
1494+
with monitored_session.MonitoredSession() as session:
1495+
while not session.should_stop():
1496+
_ = session.run_step_fn(step_fn)
1497+
self.fail('An exception should be raised on the line above.')
1498+
1499+
def test_step_request_stop_with_returning_a_type(self):
1500+
with ops.Graph().as_default():
1501+
1502+
def step_fn(step_context):
1503+
del step_context
1504+
return 'a type'
1505+
1506+
with monitored_session.MonitoredSession() as session:
1507+
self.assertEqual('a type', session.run_step_fn(step_fn))
1508+
1509+
def test_step_with_extra_arguments(self):
1510+
with ops.Graph().as_default():
1511+
1512+
def step_fn(step_context, extra_foo):
1513+
del step_context, extra_foo
1514+
1515+
with monitored_session.MonitoredSession() as session:
1516+
with self.assertRaisesRegexp(
1517+
ValueError,
1518+
'`step_fn` may either have one `step_context` argument'):
1519+
self.assertEqual(None, session.run_step_fn(step_fn))
1520+
1521+
def test_step_fn_belongs_to_a_class(self):
1522+
with ops.Graph().as_default():
1523+
c = array_ops.placeholder(dtypes.float32)
1524+
v = array_ops.identity(c)
1525+
1526+
class Model(object):
1527+
1528+
def step_fn(self, step_context):
1529+
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
1530+
return value
1531+
1532+
with monitored_session.MonitoredSession() as session:
1533+
model = Model()
1534+
self.assertNear(3.2, session.run_step_fn(model.step_fn), 0.1)
1535+
1536+
def test_step_fn_belongs_to_a_class_and_has_extra_methods(self):
1537+
with ops.Graph().as_default():
1538+
1539+
class Model(object):
1540+
1541+
def step_fn(self, step_context, extra_foo):
1542+
del step_context, extra_foo
1543+
1544+
with monitored_session.MonitoredSession() as session:
1545+
with self.assertRaisesRegexp(
1546+
ValueError,
1547+
'`step_fn` may either have one `step_context` argument'):
1548+
model = Model()
1549+
self.assertEqual(None, session.run_step_fn(model.step_fn))
1550+
1551+
def test_step_fn_with_hooks(self):
1552+
with ops.Graph().as_default():
1553+
var = resource_variable_ops.ResourceVariable(0.0)
1554+
1555+
# This test highlights the interaction of hooks with
1556+
# `MonitoredSession.run_step_fn`. The order of execution of operations
1557+
# below is:
1558+
# 0. stage_0
1559+
# 1. stage_1_0 or stage_1_1 in an undefined order
1560+
# 2. stage_2
1561+
1562+
stage_0 = state_ops.assign_add(var, 0.3)
1563+
stage_1_0 = state_ops.assign_add(var, 0.7)
1564+
# The order of `stage_1_0` and `stage_1_1` is undefined by
1565+
# `MonitoredSession`, but we should be able to assert when both of them
1566+
# are complete. To obtain a consistent result of adding two different
1567+
# constants to `var`, we rely on a control dependency and
1568+
# `ResourceVariable`. Otherwise, it is possible that one of the
1569+
# additions overwrites the result of the other addition.
1570+
with ops.control_dependencies([stage_1_0]):
1571+
stage_1_1 = state_ops.assign_add(var, 0.5)
1572+
stage_2 = state_ops.assign_add(var, 1.1)
1573+
1574+
class Hook(session_run_hook.SessionRunHook):
1575+
1576+
def __init__(self, testing):
1577+
self._testing = testing
1578+
1579+
def before_run(self, run_context):
1580+
return session_run_hook.SessionRunArgs(fetches=stage_1_0)
1581+
1582+
def after_run(self, run_context, run_values):
1583+
self._testing.assertNear(0.3 + 0.5 + 0.7,
1584+
run_context.session.run(var), 0.1)
1585+
self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1,
1586+
run_context.session.run(stage_2), 0.1)
1587+
1588+
def step_fn(step_context):
1589+
self.assertNear(0.3, step_context.session.run(stage_0), 0.1)
1590+
return step_context.run_with_hooks(fetches=stage_1_1)
1591+
1592+
with monitored_session.MonitoredSession(hooks=[Hook(self)]) as session:
1593+
self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
1594+
1595+
def test_step_fn_with_hooks_and_request_stop(self):
1596+
with ops.Graph().as_default():
1597+
trace_the_hook = {'before_run': False, 'after_run': False}
1598+
1599+
class Hook(session_run_hook.SessionRunHook):
1600+
1601+
def before_run(self, run_context):
1602+
trace_the_hook['before_run'] = True
1603+
1604+
def after_run(self, run_context, run_values):
1605+
trace_the_hook['after_run'] = True
1606+
1607+
def step_fn(step_context):
1608+
step_context.request_stop()
1609+
1610+
with monitored_session.MonitoredSession(hooks=[Hook()]) as session:
1611+
self.assertEqual(None, session.run_step_fn(step_fn))
1612+
self.assertTrue(session.should_stop())
1613+
# `step_context.request_stop()` in a step_fn interrupts the flow of
1614+
# running the hooks.
1615+
self.assertFalse(trace_the_hook['before_run'])
1616+
self.assertFalse(trace_the_hook['after_run'])
1617+
14521618

14531619
class SingularMonitoredSessionTest(test.TestCase):
14541620
"""Tests SingularMonitoredSession."""
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
path: "tensorflow.train.MonitoredSession.StepContext"
2+
tf_class {
3+
is_instance: "<class \'tensorflow.python.training.monitored_session.StepContext\'>"
4+
is_instance: "<type \'object\'>"
5+
member {
6+
name: "session"
7+
mtype: "<type \'property\'>"
8+
}
9+
member_method {
10+
name: "__init__"
11+
argspec: "args=[\'self\', \'session\', \'run_with_hooks_fn\'], varargs=None, keywords=None, defaults=None"
12+
}
13+
member_method {
14+
name: "request_stop"
15+
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
16+
}
17+
member_method {
18+
name: "run_with_hooks"
19+
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
20+
}
21+
}

tensorflow/tools/api/golden/tensorflow.train.-monitored-session.pbtxt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ tf_class {
33
is_instance: "<class \'tensorflow.python.training.monitored_session.MonitoredSession\'>"
44
is_instance: "<class \'tensorflow.python.training.monitored_session._MonitoredSession\'>"
55
is_instance: "<type \'object\'>"
6+
member {
7+
name: "StepContext"
8+
mtype: "<type \'type\'>"
9+
}
610
member {
711
name: "graph"
812
mtype: "<type \'property\'>"
@@ -19,6 +23,10 @@ tf_class {
1923
name: "run"
2024
argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
2125
}
26+
member_method {
27+
name: "run_step_fn"
28+
argspec: "args=[\'self\', \'step_fn\'], varargs=None, keywords=None, defaults=None"
29+
}
2230
member_method {
2331
name: "should_stop"
2432
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
path: "tensorflow.train.SingularMonitoredSession.StepContext"
2+
tf_class {
3+
is_instance: "<class \'tensorflow.python.training.monitored_session.StepContext\'>"
4+
is_instance: "<type \'object\'>"
5+
member {
6+
name: "session"
7+
mtype: "<type \'property\'>"
8+
}
9+
member_method {
10+
name: "__init__"
11+
argspec: "args=[\'self\', \'session\', \'run_with_hooks_fn\'], varargs=None, keywords=None, defaults=None"
12+
}
13+
member_method {
14+
name: "request_stop"
15+
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
16+
}
17+
member_method {
18+
name: "run_with_hooks"
19+
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
20+
}
21+
}

0 commit comments

Comments
 (0)