Fixes assertAllEqual() function in framework/test_util.py such that the function has the originally intended behavior without breaking PY3 compatibility.

PiperOrigin-RevId: 313906279
Change-Id: I267a8310e0bad813a7dbcf28f5e14b9f4cd66203
This commit is contained in:
Hye Soo Yang
2020-05-29 22:55:37 -07:00
committed by TensorFlower Gardener
parent da8af6adef
commit e2dfc382e6

View File

@@ -2713,8 +2713,26 @@ class TensorFlowTestCase(googletest.TestCase):
x, y = a, b
msgs.append("not equal lhs = %r" % x)
msgs.append("not equal rhs = %r" % y)
# With Python 3, we need to make sure the dtype matches between a and b.
b = b.astype(a.dtype)
# Handle mixed string types as a result of PY2to3 migration. That is, the
# mixing between bytes (b-prefix strings, PY2 default) and unicodes
# (u-prefix strings, PY3 default).
if six.PY3:
if (a.dtype.kind != b.dtype.kind and
{a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})):
a_list = []
b_list = []
# OK to flatten `a` and `b` because they are guaranteed to have the
# same shape.
for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]:
for item in flat_arr:
if isinstance(item, str):
out_list.append(item.encode("utf-8"))
else:
out_list.append(item)
a = np.array(a_list)
b = np.array(b_list)
np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
@py_func_if_in_function