diff --git a/caffe2/python/utils.py b/caffe2/python/utils.py index 46ea27c3341..75d141c7557 100644 --- a/caffe2/python/utils.py +++ b/caffe2/python/utils.py @@ -357,7 +357,10 @@ def BuildUniqueMutexIter( from caffe2.python import core if not init_net.BlobIsDefined(iter): # Add training operators. - with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)): + with core.DeviceScope( + core.DeviceOption(caffe2_pb2.CPU, + extra_info=["device_type_override:cpu"]) + ): iteration = init_net.ConstantFill( [], iter, diff --git a/caffe2/python/utils_test.py b/caffe2/python/utils_test.py index a1fa9a1fe88..3921f3d67ca 100644 --- a/caffe2/python/utils_test.py +++ b/caffe2/python/utils_test.py @@ -3,7 +3,7 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -from caffe2.python import utils, test_util +from caffe2.python import core, utils, test_util import numpy as np @@ -25,3 +25,16 @@ class TestUtils(test_util.TestCase): "stringlist1" : [b"foo", b"bar"]} self.assertEqual(dict_, expected, "dictionary version of arguments " "doesn't match original") + + def testBuildUniqueMutexIter(self): + init_net = core.Net("init_net") + net = core.Net("net") + utils.BuildUniqueMutexIter(init_net, net) + + for op in init_net.Proto().op: + self.assertEqual(op.device_option.extra_info[0], + "device_type_override:cpu") + + for op in net.Proto().op: + self.assertEqual(op.device_option.extra_info[0], + "device_type_override:cpu")