mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Fixes assertAllEqual() function in framework/test_util.py such that the function has the originally intended behavior without breaking PY3 compatibility.
PiperOrigin-RevId: 313906279 Change-Id: I267a8310e0bad813a7dbcf28f5e14b9f4cd66203
This commit is contained in:
committed by
TensorFlower Gardener
parent
da8af6adef
commit
e2dfc382e6
@@ -2713,8 +2713,26 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
x, y = a, b
|
||||
msgs.append("not equal lhs = %r" % x)
|
||||
msgs.append("not equal rhs = %r" % y)
|
||||
# With Python 3, we need to make sure the dtype matches between a and b.
|
||||
b = b.astype(a.dtype)
|
||||
|
||||
# Handle mixed string types as a result of PY2to3 migration. That is, the
|
||||
# mixing between bytes (b-prefix strings, PY2 default) and unicodes
|
||||
# (u-prefix strings, PY3 default).
|
||||
if six.PY3:
|
||||
if (a.dtype.kind != b.dtype.kind and
|
||||
{a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})):
|
||||
a_list = []
|
||||
b_list = []
|
||||
# OK to flatten `a` and `b` because they are guaranteed to have the
|
||||
# same shape.
|
||||
for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]:
|
||||
for item in flat_arr:
|
||||
if isinstance(item, str):
|
||||
out_list.append(item.encode("utf-8"))
|
||||
else:
|
||||
out_list.append(item)
|
||||
a = np.array(a_list)
|
||||
b = np.array(b_list)
|
||||
|
||||
np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
|
||||
|
||||
@py_func_if_in_function
|
||||
|
||||
Reference in New Issue
Block a user