From e2dfc382e6be58fff6ee6d0969f8925e531ac998 Mon Sep 17 00:00:00 2001 From: Hye Soo Yang Date: Fri, 29 May 2020 22:55:37 -0700 Subject: [PATCH] Fixes `assertAllEqual()` function in framework/test_util.py such that the function has the originally intended behavior without breaking PY3 compatibility. PiperOrigin-RevId: 313906279 Change-Id: I267a8310e0bad813a7dbcf28f5e14b9f4cd66203 --- tensorflow/python/framework/test_util.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 1adec3d68fd..aa52bbd8726 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -2713,8 +2713,26 @@ class TensorFlowTestCase(googletest.TestCase): x, y = a, b msgs.append("not equal lhs = %r" % x) msgs.append("not equal rhs = %r" % y) - # With Python 3, we need to make sure the dtype matches between a and b. - b = b.astype(a.dtype) + + # Handle mixed string types as a result of PY2to3 migration. That is, the + # mixing between bytes (b-prefix strings, PY2 default) and unicodes + # (u-prefix strings, PY3 default). + if six.PY3: + if (a.dtype.kind != b.dtype.kind and + {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})): + a_list = [] + b_list = [] + # OK to flatten `a` and `b` because they are guaranteed to have the + # same shape. + for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]: + for item in flat_arr: + if isinstance(item, str): + out_list.append(item.encode("utf-8")) + else: + out_list.append(item) + a = np.array(a_list) + b = np.array(b_list) + np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs)) @py_func_if_in_function