Skip to content

Commit 698921f

Browse files
committed
Merge branch 'moe_writing_branch_from_4075ca967e5d2a78929eec99e6cc23a14fde4790'
2 parents f37afc2 + c60af56 commit 698921f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+10624
-724
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""Observable base class for iterables."""
17+
18+
19+
class ObservableMixin(object):
  """Mixin that lets an iterable broadcast the values it produces.

  A subclass is expected to invoke self.notify_observers for every
  object it yields; each registered callback then receives that object.
  """

  def __init__(self):
    # Callbacks invoked, in registration order, on every notification.
    self.observers = []

  def register_observer(self, callback):
    """Registers a callback of the form callback(value, **kwargs)."""
    self.observers.append(callback)

  def notify_observers(self, value, **kwargs):
    """Forwards value (plus any keyword arguments) to every observer."""
    for observer in self.observers:
      observer(value, **kwargs)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the Observable mixin class."""
16+
17+
import logging
18+
import unittest
19+
20+
21+
from google.cloud.dataflow.coders import observable
22+
23+
24+
class ObservableMixinTest(unittest.TestCase):
  """Tests that ObservableMixin notifies registered observers."""

  def setUp(self):
    # Initialize per-test state here rather than as class attributes:
    # a class-level `observed_keys = []` is a single mutable list shared
    # by every instance, so observations would leak between tests.
    self.observed_count = 0
    self.observed_sum = 0
    self.observed_keys = []

  def observer(self, value, key=None):
    """Records one observed value and the key it was tagged with."""
    self.observed_count += 1
    self.observed_sum += value
    self.observed_keys.append(key)

  def test_observable(self):
    class Watched(observable.ObservableMixin):
      """An observable iterable that yields a fixed sequence."""

      def __iter__(self):
        for i in (1, 4, 3):
          self.notify_observers(i, key='a%d' % i)
          yield i

    watched = Watched()
    watched.register_observer(lambda v, key: self.observer(v, key=key))
    for _ in watched:  # Drain the iterable; observation is the side effect.
      pass

    # assertEqual is the non-deprecated spelling of assertEquals.
    self.assertEqual(3, self.observed_count)
    self.assertEqual(8, self.observed_sum)
    self.assertEqual(['a1', 'a3', 'a4'], sorted(self.observed_keys))
50+
51+
52+
if __name__ == '__main__':
  # Run the tests with INFO-level logging when invoked as a script.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()

google/cloud/dataflow/coders/slow_stream.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,33 @@ def get(self):
5959
return ''.join(self.data)
6060

6161

62+
class ByteCountingOutputStream(OutputStream):
  """A pure Python implementation of stream.ByteCountingOutputStream."""

  def __init__(self):
    # The buffer set up by the superclass is never read; only the running
    # byte count below matters.
    super(ByteCountingOutputStream, self).__init__()
    self.count = 0

  def write(self, byte_array, nested=False):
    """Counts len(byte_array) bytes, plus a var-int length prefix if nested."""
    num_bytes = len(byte_array)
    if nested:
      # The inherited var-int encoder emits its bytes through our
      # write_byte override, so the prefix is counted as well.
      self.write_var_int64(num_bytes)
    self.count += num_bytes

  def write_byte(self, _):
    """Counts one byte; the value itself is discarded."""
    self.count += 1

  def get_count(self):
    """Returns the number of bytes counted so far."""
    return self.count

  def get(self):
    # Nothing is ever stored, so there is no data to return.
    raise NotImplementedError

  def __str__(self):
    return '<%s %s>' % (self.__class__.__name__, self.count)
6289
class InputStream(object):
6390
"""A pure Python implementation of stream.InputStream."""
6491

google/cloud/dataflow/coders/stream.pxd

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ cdef class OutputStream(object):
3232
cdef extend(self, size_t missing)
3333

3434

35+
cdef class ByteCountingOutputStream(OutputStream):
  # Declarations for an output stream that only counts written bytes;
  # the implementation lives in stream.pyx.
  cdef size_t count  # Running total of bytes "written".

  cpdef write(self, bytes b, bint nested=*)
  cpdef write_byte(self, unsigned char val)
  cpdef write_bigendian_int64(self, libc.stdint.int64_t val)
  cpdef write_bigendian_int32(self, libc.stdint.int32_t val)
  cpdef size_t get_count(self)
  cpdef bytes get(self)  # Always raises; counting streams keep no data.
3546
cdef class InputStream(object):
3647
cdef size_t pos
3748
cdef bytes all

google/cloud/dataflow/coders/stream.pyx

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,44 @@ cdef class OutputStream(object):
9494
self.data = <char*>libc.stdlib.realloc(self.data, self.size)
9595

9696

97+
cdef class ByteCountingOutputStream(OutputStream):
  """An output string stream implementation that only counts the bytes.

  This implementation counts the number of bytes it "writes" but
  doesn't actually write them anywhere. Thus it has write() but not
  get(). get_count() returns how many bytes were written.

  This is useful for sizing an encoding.
  """

  def __cinit__(self):
    self.count = 0

  cpdef write(self, bytes b, bint nested=False):
    # Counts len(b) bytes; when nested, the var-int length prefix is
    # counted too (its bytes reach write_byte below).
    cdef size_t blen = len(b)
    if nested:
      self.write_var_int64(blen)
    self.count += blen

  cpdef write_byte(self, unsigned char _):
    # One byte per call; the value itself is discarded.
    self.count += 1

  cpdef write_bigendian_int64(self, libc.stdint.int64_t _):
    # A big-endian int64 always occupies 8 bytes.
    self.count += 8

  cpdef write_bigendian_int32(self, libc.stdint.int32_t _):
    # A big-endian int32 always occupies 4 bytes.
    self.count += 4

  cpdef size_t get_count(self):
    return self.count

  cpdef bytes get(self):
    # Nothing is ever stored, so there is no data to return.
    raise NotImplementedError

  def __str__(self):
    return '<%s %s>' % (self.__class__.__name__, self.count)
97135
cdef class InputStream(object):
98136
"""An input string stream implementation supporting read() and size()."""
99137

google/cloud/dataflow/coders/stream_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Tests for the stream implementations."""
1616

17+
import logging
1718
import math
1819
import unittest
1920

@@ -22,8 +23,11 @@
2223

2324

2425
class StreamTest(unittest.TestCase):
26+
# pylint: disable=invalid-name
2527
InputStream = slow_stream.InputStream
2628
OutputStream = slow_stream.OutputStream
29+
ByteCountingOutputStream = slow_stream.ByteCountingOutputStream
30+
# pylint: enable=invalid-name
2731

2832
def test_read_write(self):
2933
out_s = self.OutputStream()
@@ -99,6 +103,27 @@ def test_read_write_bigendian_int32(self):
99103
for v in values:
100104
self.assertEquals(v, in_s.read_bigendian_int32())
101105

106+
def test_byte_counting(self):
  """Checks that ByteCountingOutputStream counts, but never stores, bytes."""
  bc_s = self.ByteCountingOutputStream()
  # assertEqual is the non-deprecated spelling of assertEquals.
  self.assertEqual(0, bc_s.get_count())
  bc_s.write('def')
  self.assertEqual(3, bc_s.get_count())
  bc_s.write('')
  self.assertEqual(3, bc_s.get_count())
  bc_s.write_byte(10)
  self.assertEqual(4, bc_s.get_count())
  # "nested" also writes the length of the string, which should
  # cause 1 extra byte to be counted.
  bc_s.write('2345', nested=True)
  self.assertEqual(9, bc_s.get_count())
  bc_s.write_var_int64(63)
  self.assertEqual(10, bc_s.get_count())
  bc_s.write_bigendian_int64(42)
  self.assertEqual(18, bc_s.get_count())
  bc_s.write_bigendian_int32(36)
  self.assertEqual(22, bc_s.get_count())
  bc_s.write_bigendian_double(6.25)
  self.assertEqual(30, bc_s.get_count())
102127

103128
try:
104129
# pylint: disable=g-import-not-at-top
@@ -108,22 +133,26 @@ class FastStreamTest(StreamTest):
108133
"""Runs the test with the compiled stream classes."""
109134
InputStream = stream.InputStream
110135
OutputStream = stream.OutputStream
136+
ByteCountingOutputStream = stream.ByteCountingOutputStream
111137

112138

113139
class SlowFastStreamTest(StreamTest):
114140
"""Runs the test with compiled and uncompiled stream classes."""
115141
InputStream = stream.InputStream
116142
OutputStream = slow_stream.OutputStream
143+
ByteCountingOutputStream = slow_stream.ByteCountingOutputStream
117144

118145

119146
class FastSlowStreamTest(StreamTest):
120147
"""Runs the test with uncompiled and compiled stream classes."""
121148
InputStream = slow_stream.InputStream
122149
OutputStream = stream.OutputStream
150+
ByteCountingOutputStream = stream.ByteCountingOutputStream
123151

124152
except ImportError:
125153
pass
126154

127155

128156
if __name__ == '__main__':
157+
logging.getLogger().setLevel(logging.INFO)
129158
unittest.main()

google/cloud/dataflow/dataflow_test.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,19 +250,91 @@ def match(actual):
250250
assert_that(results, matcher(1, a_list, some_pairs))
251251
pipeline.run()
252252

253-
def test_as_list_without_unique_labels(self):
254-
a_list = [1, 2, 3]
253+
def test_as_singleton_without_unique_labels(self):
  """Reuses AsSingleton on one PCollection with identical defaults."""
  # This should succeed as calling AsSingleton on the same PCollection twice
  # with the same defaults will return the same PCollectionView.
  a_list = [2]
  pipeline = Pipeline('DirectPipelineRunner')
  main_input = pipeline | Create('main input', [1])
  side_list = pipeline | Create('side list', a_list)
  results = main_input | FlatMap(
      'test',
      lambda x, s1, s2: [[x, s1, s2]],
      AsSingleton(side_list), AsSingleton(side_list))

  def matcher(expected_elem, expected_singleton):
    # Builds a match callback asserting both side inputs saw the singleton.
    def match(actual):
      [[actual_elem, actual_singleton1, actual_singleton2]] = actual
      equal_to([expected_elem])([actual_elem])
      equal_to([expected_singleton])([actual_singleton1])
      equal_to([expected_singleton])([actual_singleton2])
    return match

  assert_that(results, matcher(1, 2))
  pipeline.run()
275+
276+
def test_as_singleton_with_different_defaults_without_unique_labels(self):
  """Verifies distinct defaults without labels raise a label conflict."""
  # This should fail as AsSingleton with distinct default values should create
  # distinct PCollectionViews with the same full_label.
  a_list = [2]
  pipeline = Pipeline('DirectPipelineRunner')
  main_input = pipeline | Create('main input', [1])
  side_list = pipeline | Create('side list', a_list)

  with self.assertRaises(RuntimeError) as e:
    _ = main_input | FlatMap(
        'test',
        lambda x, s1, s2: [[x, s1, s2]],
        AsSingleton(side_list), AsSingleton(side_list, default_value=3))
  self.assertTrue(
      e.exception.message.startswith(
          'Transform "ViewAsSingleton(side list.None)" does not have a '
          'stable unique label.'))
293+
294+
def test_as_singleton_with_different_defaults_with_unique_labels(self):
  """Uses explicit labels so two differently-defaulted singletons coexist."""
  # Empty side list: each AsSingleton falls back to its own default value.
  a_list = []
  pipeline = Pipeline('DirectPipelineRunner')
  main_input = pipeline | Create('main input', [1])
  side_list = pipeline | Create('side list', a_list)
  results = main_input | FlatMap(
      'test',
      lambda x, s1, s2: [[x, s1, s2]],
      AsSingleton('si1', side_list, default_value=2),
      AsSingleton('si2', side_list, default_value=3))

  def matcher(expected_elem, expected_singleton1, expected_singleton2):
    # Builds a match callback asserting each side input got its own default.
    def match(actual):
      [[actual_elem, actual_singleton1, actual_singleton2]] = actual
      equal_to([expected_elem])([actual_elem])
      equal_to([expected_singleton1])([actual_singleton1])
      equal_to([expected_singleton2])([actual_singleton2])
    return match

  assert_that(results, matcher(1, 2, 3))
  pipeline.run()
315+
316+
def test_as_list_without_unique_labels(self):
  """Reuses AsList on one PCollection without explicit labels."""
  # This should succeed as calling AsList on the same PCollection twice will
  # return the same PCollectionView.
  a_list = [1, 2, 3]
  pipeline = Pipeline('DirectPipelineRunner')
  main_input = pipeline | Create('main input', [1])
  side_list = pipeline | Create('side list', a_list)
  results = main_input | FlatMap(
      'test',
      lambda x, ls1, ls2: [[x, ls1, ls2]],
      AsList(side_list), AsList(side_list))

  def matcher(expected_elem, expected_list):
    # Builds a match callback asserting both side inputs saw the full list.
    def match(actual):
      [[actual_elem, actual_list1, actual_list2]] = actual
      equal_to([expected_elem])([actual_elem])
      equal_to(expected_list)(actual_list1)
      equal_to(expected_list)(actual_list2)
    return match

  assert_that(results, matcher(1, [1, 2, 3]))
  pipeline.run()
266338

267339
def test_as_list_with_unique_labels(self):
268340
a_list = [1, 2, 3]
@@ -282,6 +354,9 @@ def match(actual):
282354
equal_to(expected_list)(actual_list2)
283355
return match
284356

357+
assert_that(results, matcher(1, [1, 2, 3]))
358+
pipeline.run()
359+
285360
def test_as_dict_with_unique_labels(self):
286361
some_kvs = [('a', 1), ('b', 2)]
287362
pipeline = Pipeline('DirectPipelineRunner')

google/cloud/dataflow/examples/cookbook/bigquery_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def run(argv=None):
4141

4242
p = df.Pipeline(argv=pipeline_args)
4343

44-
from apitools.clients import bigquery # pylint: disable=g-import-not-at-top
44+
from google.cloud.dataflow.internal.clients import bigquery # pylint: disable=g-import-not-at-top
4545

4646
table_schema = bigquery.TableSchema()
4747

0 commit comments

Comments
 (0)