|
33 | 33 | from tensorflow.core.protobuf import debug_pb2 |
34 | 34 | from tensorflow.python.client import session as session_lib |
35 | 35 | from tensorflow.python.framework import constant_op |
| 36 | +from tensorflow.python.framework import dtypes |
36 | 37 | from tensorflow.python.framework import errors_impl |
37 | 38 | from tensorflow.python.framework import ops |
38 | 39 | from tensorflow.python.ops import array_ops |
39 | 40 | from tensorflow.python.ops import control_flow_ops |
| 41 | +from tensorflow.python.ops import resource_variable_ops |
40 | 42 | from tensorflow.python.ops import state_ops |
41 | 43 | from tensorflow.python.ops import variables |
42 | 44 | from tensorflow.python.platform import test |
@@ -1449,6 +1451,170 @@ def test_with_statement_and_close(self): |
1449 | 1451 | with monitored_session.MonitoredSession() as session: |
1450 | 1452 | session.close() |
1451 | 1453 |
|
| 1454 | + def test_step_fn_example(self): |
| 1455 | + with ops.Graph().as_default(): |
| 1456 | + c = array_ops.placeholder(dtypes.float32) |
| 1457 | + v = array_ops.identity(c) |
| 1458 | + |
| 1459 | + def step_fn(step_context): |
| 1460 | + value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2}) |
| 1461 | + return value |
| 1462 | + |
| 1463 | + with monitored_session.MonitoredSession() as session: |
| 1464 | + self.assertNear(3.2, session.run_step_fn(step_fn), 0.1) |
| 1465 | + |
| 1466 | + def test_step_function_stops(self): |
| 1467 | + with ops.Graph().as_default(): |
| 1468 | + |
| 1469 | + def step_fn(step_context): |
| 1470 | + step_context.request_stop() |
| 1471 | + |
| 1472 | + with monitored_session.MonitoredSession() as session: |
| 1473 | + self.assertEqual(None, session.run_step_fn(step_fn)) |
| 1474 | + self.assertTrue(session.should_stop()) |
| 1475 | + |
| 1476 | + def test_step_request_stop_without_a_with_block(self): |
| 1477 | + with ops.Graph().as_default(): |
| 1478 | + |
| 1479 | + def step_fn(step_context): |
| 1480 | + step_context.request_stop() |
| 1481 | + |
| 1482 | + session = monitored_session.MonitoredSession() |
| 1483 | + try: |
| 1484 | + self.assertEqual(None, session.run_step_fn(step_fn)) |
| 1485 | + except StopIteration: |
| 1486 | + pass |
| 1487 | + self.assertTrue(session.should_stop()) |
| 1488 | + |
| 1489 | + def test_step_request_stop_in_a_loop(self): |
| 1490 | + with ops.Graph().as_default(): |
| 1491 | + def step_fn(step_context): |
| 1492 | + step_context.request_stop() |
| 1493 | + |
| 1494 | + with monitored_session.MonitoredSession() as session: |
| 1495 | + while not session.should_stop(): |
| 1496 | + _ = session.run_step_fn(step_fn) |
| 1497 | + self.fail('An exception should be raised on the line above.') |
| 1498 | + |
| 1499 | + def test_step_request_stop_with_returning_a_type(self): |
| 1500 | + with ops.Graph().as_default(): |
| 1501 | + |
| 1502 | + def step_fn(step_context): |
| 1503 | + del step_context |
| 1504 | + return 'a type' |
| 1505 | + |
| 1506 | + with monitored_session.MonitoredSession() as session: |
| 1507 | + self.assertEqual('a type', session.run_step_fn(step_fn)) |
| 1508 | + |
| 1509 | + def test_step_with_extra_arguments(self): |
| 1510 | + with ops.Graph().as_default(): |
| 1511 | + |
| 1512 | + def step_fn(step_context, extra_foo): |
| 1513 | + del step_context, extra_foo |
| 1514 | + |
| 1515 | + with monitored_session.MonitoredSession() as session: |
| 1516 | + with self.assertRaisesRegexp( |
| 1517 | + ValueError, |
| 1518 | + '`step_fn` may either have one `step_context` argument'): |
| 1519 | + self.assertEqual(None, session.run_step_fn(step_fn)) |
| 1520 | + |
| 1521 | + def test_step_fn_belongs_to_a_class(self): |
| 1522 | + with ops.Graph().as_default(): |
| 1523 | + c = array_ops.placeholder(dtypes.float32) |
| 1524 | + v = array_ops.identity(c) |
| 1525 | + |
| 1526 | + class Model(object): |
| 1527 | + |
| 1528 | + def step_fn(self, step_context): |
| 1529 | + value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2}) |
| 1530 | + return value |
| 1531 | + |
| 1532 | + with monitored_session.MonitoredSession() as session: |
| 1533 | + model = Model() |
| 1534 | + self.assertNear(3.2, session.run_step_fn(model.step_fn), 0.1) |
| 1535 | + |
| 1536 | + def test_step_fn_belongs_to_a_class_and_has_extra_methods(self): |
| 1537 | + with ops.Graph().as_default(): |
| 1538 | + |
| 1539 | + class Model(object): |
| 1540 | + |
| 1541 | + def step_fn(self, step_context, extra_foo): |
| 1542 | + del step_context, extra_foo |
| 1543 | + |
| 1544 | + with monitored_session.MonitoredSession() as session: |
| 1545 | + with self.assertRaisesRegexp( |
| 1546 | + ValueError, |
| 1547 | + '`step_fn` may either have one `step_context` argument'): |
| 1548 | + model = Model() |
| 1549 | + self.assertEqual(None, session.run_step_fn(model.step_fn)) |
| 1550 | + |
| 1551 | + def test_step_fn_with_hooks(self): |
| 1552 | + with ops.Graph().as_default(): |
| 1553 | + var = resource_variable_ops.ResourceVariable(0.0) |
| 1554 | + |
| 1555 | + # This test higlights the interaction of hooks with |
| 1556 | + # `Monitoredsession.run_step_fn`. The order of execution of operations |
| 1557 | + # below is: |
| 1558 | + # 0. stage_0 |
| 1559 | + # 1. stage_1_0 or stage_1_1 in an undefined order |
| 1560 | + # 2. stage_2 |
| 1561 | + |
| 1562 | + stage_0 = state_ops.assign_add(var, 0.3) |
| 1563 | + stage_1_0 = state_ops.assign_add(var, 0.7) |
| 1564 | + # The order of `stage_1_0` and `stage_1_1` is undefined by |
| 1565 | + # `MonitoredSession`, but we should be able to assert when both of them |
| 1566 | + # are complete. To obtain a consistent result of adding two different |
| 1567 | + # constants to `var`, we rely on a control dependency and |
| 1568 | + # `ResourceVariable`. Otherwise, it is possible that one of the |
| 1569 | + # additions overwites the result of the other addition. |
| 1570 | + with ops.control_dependencies([stage_1_0]): |
| 1571 | + stage_1_1 = state_ops.assign_add(var, 0.5) |
| 1572 | + stage_2 = state_ops.assign_add(var, 1.1) |
| 1573 | + |
| 1574 | + class Hook(session_run_hook.SessionRunHook): |
| 1575 | + |
| 1576 | + def __init__(self, testing): |
| 1577 | + self._testing = testing |
| 1578 | + |
| 1579 | + def before_run(self, run_context): |
| 1580 | + return session_run_hook.SessionRunArgs(fetches=stage_1_0) |
| 1581 | + |
| 1582 | + def after_run(self, run_context, run_values): |
| 1583 | + self._testing.assertNear(0.3 + 0.5 + 0.7, |
| 1584 | + run_context.session.run(var), 0.1) |
| 1585 | + self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1, |
| 1586 | + run_context.session.run(stage_2), 0.1) |
| 1587 | + |
| 1588 | + def step_fn(step_context): |
| 1589 | + self.assertNear(0.3, step_context.session.run(stage_0), 0.1) |
| 1590 | + return step_context.run_with_hooks(fetches=stage_1_1) |
| 1591 | + |
| 1592 | + with monitored_session.MonitoredSession(hooks=[Hook(self)]) as session: |
| 1593 | + self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn)) |
| 1594 | + |
| 1595 | + def test_step_fn_with_hooks_and_request_stop(self): |
| 1596 | + with ops.Graph().as_default(): |
| 1597 | + trace_the_hook = {'before_run': False, 'after_run': False} |
| 1598 | + |
| 1599 | + class Hook(session_run_hook.SessionRunHook): |
| 1600 | + |
| 1601 | + def before_run(self, run_context): |
| 1602 | + trace_the_hook['before_run'] = True |
| 1603 | + |
| 1604 | + def after_run(self, run_context, run_values): |
| 1605 | + trace_the_hook['after_run'] = True |
| 1606 | + |
| 1607 | + def step_fn(step_context): |
| 1608 | + step_context.request_stop() |
| 1609 | + |
| 1610 | + with monitored_session.MonitoredSession(hooks=[Hook()]) as session: |
| 1611 | + self.assertEqual(None, session.run_step_fn(step_fn)) |
| 1612 | + self.assertTrue(session.should_stop()) |
| 1613 | + # `step_context.request_stop()` in a step_fn interrupts the flow of |
| 1614 | + # running the hooks. |
| 1615 | + self.assertFalse(trace_the_hook['before_run']) |
| 1616 | + self.assertFalse(trace_the_hook['after_run']) |
| 1617 | + |
1452 | 1618 |
|
1453 | 1619 | class SingularMonitoredSessionTest(test.TestCase): |
1454 | 1620 | """Tests SingularMonitoredSession.""" |
|
0 commit comments