# Patch summary (preamble text placed before the first "diff --git" header).
#
# Scope: one hunk inside Parallelize_GPU_BMUF in
# caffe2/python/data_parallel_model.py.
#
# What the hunk does:
#   * Adds a local `num_worker_threads = 4 * len(devices)`.
#   * Assigns `num_worker_threads` (previously `num_workers`) to
#     Proto().num_workers on three nets:
#       - model_helper_obj.net
#       - model_helper_obj._global_model_init_net
#       - model_helper_obj._global_model_param_updates_net
#   * Leaves `num_workers = len(devices)` in place; it is still what
#     `loss_scale` and the default `block_momentum` are computed from, so
#     those values are unchanged by this patch.
#
# NOTE(review): the multiplier 4 is hard-coded and its rationale is not
# visible in this hunk -- presumably a per-device thread count; confirm and
# consider documenting it (or deriving it from `num_workers`) in the source.
# NOTE(review): the context comments contain typos ("Its called once",
# "Its will run once"); they are left untouched here because context lines
# must match the target file for the hunk to apply.
# NOTE(review): line breaks and indentation below were reconstructed from a
# whitespace-collapsed copy, guided by the hunk line counts (25 old / 27
# new); verify against the target file before applying.
diff --git a/caffe2/python/data_parallel_model.py b/caffe2/python/data_parallel_model.py
index 8a4a79f4034..c95bb8f0fd1 100644
--- a/caffe2/python/data_parallel_model.py
+++ b/caffe2/python/data_parallel_model.py
@@ -259,25 +259,27 @@ def Parallelize_GPU_BMUF(
     master_gpu_opt = core.DeviceOption(caffe2_pb2.CUDA, master_gpu)
 
     num_workers = len(devices)
+    num_worker_threads = 4 * len(devices)
     loss_scale = 1.0 / num_workers
     if block_momentum is None:
         block_momentum = 1.0 - 1.0 / num_workers
 
-    model_helper_obj.net.Proto().num_workers = num_workers
+    model_helper_obj.net.Proto().num_workers = num_worker_threads
     model_helper_obj.net.Proto().type = net_type
 
     # A net for initializing global model parameters. Its called once in the
     # same step as net parameters initialization.
     model_helper_obj._global_model_init_net = core.Net('global_model_init')
     model_helper_obj._global_model_init_net.Proto().type = net_type
-    model_helper_obj._global_model_init_net.Proto().num_workers = num_workers
+    model_helper_obj._global_model_init_net.Proto().num_workers = \
+        num_worker_threads
 
     # A net for computing final parameter updates. Its will run once after
     # running net (local models updates) for `num_local_iterations` times.
     model_helper_obj._global_model_param_updates_net = core.Net('global_model')
     model_helper_obj._global_model_param_updates_net.Proto().type = net_type
     model_helper_obj._global_model_param_updates_net.Proto().num_workers = \
-        num_workers
+        num_worker_threads
 
     def _v(param):
         return "{}_v".format(param)