mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[C2] Introduce extra_info force CPU tags for auto-generated iteration counter blobs (#32607)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32607 As desc. Test Plan: Unit-test. Reviewed By: xw285cornell, chocjy Differential Revision: D19551567 fbshipit-source-id: 3a121351d2b4016e99a1536dec746be970698664
This commit is contained in:
committed by
Facebook Github Bot
parent
3c17cbb6c8
commit
e76fa9822d
@@ -357,7 +357,10 @@ def BuildUniqueMutexIter(
|
||||
from caffe2.python import core
|
||||
if not init_net.BlobIsDefined(iter):
|
||||
# Add training operators.
|
||||
with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
|
||||
with core.DeviceScope(
|
||||
core.DeviceOption(caffe2_pb2.CPU,
|
||||
extra_info=["device_type_override:cpu"])
|
||||
):
|
||||
iteration = init_net.ConstantFill(
|
||||
[],
|
||||
iter,
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import utils, test_util
|
||||
from caffe2.python import core, utils, test_util
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -25,3 +25,16 @@ class TestUtils(test_util.TestCase):
|
||||
"stringlist1" : [b"foo", b"bar"]}
|
||||
self.assertEqual(dict_, expected, "dictionary version of arguments "
|
||||
"doesn't match original")
|
||||
|
||||
def testBuildUniqueMutexIter(self):
|
||||
init_net = core.Net("init_net")
|
||||
net = core.Net("net")
|
||||
utils.BuildUniqueMutexIter(init_net, net)
|
||||
|
||||
for op in init_net.Proto().op:
|
||||
self.assertEqual(op.device_option.extra_info[0],
|
||||
"device_type_override:cpu")
|
||||
|
||||
for op in net.Proto().op:
|
||||
self.assertEqual(op.device_option.extra_info[0],
|
||||
"device_type_override:cpu")
|
||||
|
||||
Reference in New Issue
Block a user