Skip to content

Commit fbb5392

Browse files
Merged commit includes the following changes:
179221620 by akshayka: Internal cleanup: Delete extraneous print statement in test case. -- 179220917 by A. Unique TensorFlower: [XLA:JF] Make HLO parser recognize negative padding. -- PiperOrigin-RevId: 179221620
1 parent 8d3690c commit fbb5392

7 files changed

Lines changed: 29 additions & 8 deletions

File tree

tensorflow/compiler/xla/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,12 @@ filegroup(
641641
visibility = ["//tensorflow:__subpackages__"],
642642
)
643643

644+
py_proto_library(
645+
name = "xla_data_proto_py_pb2",
646+
api_version = 2,
647+
deps = [":xla_data_proto"],
648+
)
649+
644650
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
645651
cc_header_only_library(
646652
name = "xla_headers_lib",

tensorflow/compiler/xla/python/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ py_library(
1111
visibility = ["//visibility:public"],
1212
deps = [
1313
":pywrap_xla",
14-
"//tensorflow/compiler/xla:xla_data_proto_py",
14+
"//tensorflow/compiler/xla:xla_data_proto_py_pb2",
1515
],
1616
)
1717

@@ -23,6 +23,7 @@ py_test(
2323
deps = [
2424
":xla_client",
2525
"//tensorflow/python:platform_test",
26+
"//third_party/py/numpy",
2627
],
2728
)
2829

@@ -51,6 +52,7 @@ cc_library(
5152
"//tensorflow/compiler/xla/client:local_client",
5253
"//tensorflow/compiler/xla/service:cpu_plugin",
5354
"//tensorflow/core:lib",
55+
"//tensorflow/stream_executor/host:host_platform",
5456
],
5557
)
5658

tensorflow/compiler/xla/python/local_computation_builder.i

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ limitations under the License.
106106

107107
#include "tensorflow/compiler/xla/literal_util.h"
108108
#include "tensorflow/compiler/xla/shape_util.h"
109-
#include "tensorflow/compiler/xla/xla_data.pb.h"
109+
#include "tensorflow/compiler/xla/xla_data.proto.h"
110110
#include "tensorflow/core/lib/gtl/array_slice.h"
111111
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
112112
#include "tensorflow/compiler/xla/python/local_computation_builder.h"

tensorflow/compiler/xla/python/xla_client_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import numpy as np
2424

2525
from tensorflow.compiler.xla.python import xla_client
26-
import unittest
26+
from tensorflow.python.platform import googletest
2727

2828

29-
class LocalComputationTest(unittest.TestCase):
29+
class LocalComputationTest(googletest.TestCase):
3030
"""Base class for running an XLA Computation through the local client."""
3131

3232
def _NewComputation(self, name=None):
@@ -895,4 +895,4 @@ def testWhileF64(self):
895895

896896

897897
if __name__ == "__main__":
898-
unittest.main()
898+
googletest.main()

tensorflow/compiler/xla/tools/parser/hlo_lexer.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ TokKind HloLexer::LexPercent() {
257257
// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
258258
// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
259259
// dxd_pattern ::= [0-9]+(x[0-9]+)+
260-
// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*
260+
// pad_pattern ::=
261+
// [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
261262
// int ::= [-]?[0-9]+
262263
// negative inf ::= '-inf'
263264
TokKind HloLexer::LexNumberOrPattern() {
@@ -275,7 +276,7 @@ TokKind HloLexer::LexNumberOrPattern() {
275276
R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
276277
static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
277278
static LazyRE2 pad_pattern = {
278-
R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"};
279+
R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};
279280

280281
if (RE2::Consume(&consumable, *dim_labels_pattern)) {
281282
current_ptr_ = consumable.begin();

tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,19 @@ ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
606606
ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
607607
}
608608
609+
)"
610+
},
611+
// Negative padding
612+
{
613+
"PadHasNegativePadding",
614+
R"(HloModule PadHasNegativePadding_module
615+
616+
ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] {
617+
%input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
618+
%constant = f32[] constant(-5.123)
619+
ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
620+
}
621+
609622
)"
610623
},
611624
// fusion

tensorflow/python/framework/ops_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2182,7 +2182,6 @@ def _get_test_attrs(self):
21822182
b = compat.as_text(x.get_attr("_B"))
21832183
except ValueError:
21842184
b = None
2185-
print(a, b)
21862185
return (a, b)
21872186

21882187
def testNoLabel(self):

0 commit comments

Comments
 (0)