mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Revert "Fix mesh.get_local_rank when it is > 1d (#164473)"
This reverts commit 83d71dfb2f.
Reverted https://github.com/pytorch/pytorch/pull/164473 on behalf of https://github.com/izaitsevfb due to appears to be causing vision_maskrcnn regression ([comment](https://github.com/pytorch/pytorch/pull/164473#issuecomment-3374738997))
This commit is contained in:
@@ -378,7 +378,7 @@ vgg16,pass,0
|
||||
|
||||
|
||||
|
||||
vision_maskrcnn,pass,18
|
||||
vision_maskrcnn,pass,20
|
||||
|
||||
|
||||
|
||||
|
||||
|
@@ -286,7 +286,7 @@ vgg16,pass,6
|
||||
|
||||
|
||||
|
||||
vision_maskrcnn,pass,37
|
||||
vision_maskrcnn,pass,39
|
||||
|
||||
|
||||
|
||||
|
||||
|
@@ -271,11 +271,7 @@ class DeviceMeshVariable(DistributedVariable):
|
||||
if name == "get_rank":
|
||||
return ConstantVariable.create(self.value.get_rank())
|
||||
if name == "get_local_rank":
|
||||
const_args = [x.as_python_constant() for x in args]
|
||||
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||
return ConstantVariable.create(
|
||||
self.value.get_local_rank(*const_args, **const_kwargs)
|
||||
)
|
||||
return ConstantVariable.create(self.value.get_local_rank())
|
||||
if name == "get_group":
|
||||
const_args = [x.as_python_constant() for x in args]
|
||||
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||
|
||||
Reference in New Issue
Block a user