mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Register RoIAlignRotated with C10
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30785 Reviewed By: wat3rBro Differential Revision: D18415056 fbshipit-source-id: e00376bec948309d53f2172697cd477449f769b2
This commit is contained in:
committed by
Facebook Github Bot
parent
b79030d6c8
commit
ef5ae4823a
@@ -403,3 +403,22 @@ Based on https://arxiv.org/abs/1703.01086.
|
||||
"is a pooled feature map cooresponding to the r-th RoI.");
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
using RoIAlignRotatedOpFloatCPU =
|
||||
caffe2::RoIAlignRotatedOp<float, caffe2::CPUContext>;
|
||||
|
||||
// clang-format off
|
||||
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
|
||||
RoIAlignRotated,
|
||||
"_caffe2::RoIAlignRotated("
|
||||
"Tensor features, "
|
||||
"Tensor rois, "
|
||||
"str order, "
|
||||
"float spatial_scale, "
|
||||
"int pooled_h, "
|
||||
"int pooled_w, "
|
||||
"int sampling_ratio, "
|
||||
"bool aligned"
|
||||
") -> Tensor",
|
||||
RoIAlignRotatedOpFloatCPU);
|
||||
// clang-format on
|
||||
|
||||
@@ -206,3 +206,8 @@ bool RoIAlignRotatedOp<float, CUDAContext>::RunOnDevice() {
|
||||
|
||||
REGISTER_CUDA_OPERATOR(RoIAlignRotated, RoIAlignRotatedOp<float, CUDAContext>);
|
||||
} // namespace caffe2
|
||||
|
||||
using RoIAlignRotatedOpFloatCUDA =
|
||||
caffe2::RoIAlignRotatedOp<float, caffe2::CUDAContext>;
|
||||
|
||||
C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(RoIAlignRotated, RoIAlignRotatedOpFloatCUDA);
|
||||
|
||||
@@ -4,9 +4,12 @@
|
||||
#define ROTATED_ROI_ALIGN_OP_H_
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/export_caffe2_op_to_c10.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(RoIAlignRotated)
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <typename T, class Context>
|
||||
|
||||
@@ -511,6 +511,63 @@ class TorchIntegration(hu.HypothesisTestCase):
|
||||
def test_roi_align_cuda(self):
|
||||
self._test_roi_align(device="cuda")
|
||||
|
||||
@given(
|
||||
N=st.integers(min_value=1, max_value=2),
|
||||
C=st.integers(min_value=4, max_value=4),
|
||||
H=st.integers(min_value=10, max_value=10),
|
||||
W=st.integers(min_value=8, max_value=8),
|
||||
)
|
||||
def _test_roi_align_rotated(self, N, C, H, W, device):
|
||||
def rand_rotated_roi():
|
||||
return np.array(
|
||||
[
|
||||
float(int(N * np.random.rand())),
|
||||
np.random.rand() * W,
|
||||
np.random.rand() * H,
|
||||
np.random.rand() * W,
|
||||
np.random.rand() * H,
|
||||
np.random.rand() * 360 - 180
|
||||
]
|
||||
).astype(np.float32)
|
||||
|
||||
feature = np.random.randn(N, C, H, W).astype(np.float32)
|
||||
rois = np.array([rand_rotated_roi() for _ in range(10)])
|
||||
|
||||
def roi_align_ref(_feature, _rois):
|
||||
ref_op = core.CreateOperator(
|
||||
"RoIAlignRotated",
|
||||
["feature", "rois"],
|
||||
["roi_feature"],
|
||||
spatial_scale=1.0,
|
||||
pooled_h=3,
|
||||
pooled_w=3,
|
||||
sampling_ratio=0,
|
||||
)
|
||||
workspace.FeedBlob("feature", _feature)
|
||||
workspace.FeedBlob("rois", _rois)
|
||||
workspace.RunOperatorOnce(ref_op)
|
||||
return workspace.FetchBlob("roi_feature")
|
||||
|
||||
roi_feature_ref = roi_align_ref(feature, rois)
|
||||
roi_feature = torch.ops._caffe2.RoIAlignRotated(
|
||||
torch.Tensor(feature).to(device),
|
||||
torch.Tensor(rois).to(device),
|
||||
order="NCHW",
|
||||
spatial_scale=1.0,
|
||||
pooled_h=3,
|
||||
pooled_w=3,
|
||||
sampling_ratio=0,
|
||||
aligned=False,
|
||||
)
|
||||
torch.testing.assert_allclose(roi_feature_ref, roi_feature.cpu())
|
||||
|
||||
def test_roi_align_rotated_cpu(self):
|
||||
self._test_roi_align_rotated(device="cpu")
|
||||
|
||||
@unittest.skipIf(not workspace.has_cuda_support, "No cuda support")
|
||||
def test_roi_align_rotated_cuda(self):
|
||||
self._test_roi_align_rotated(device="cuda")
|
||||
|
||||
@given(roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10))
|
||||
def test_collect_and_distribute_fpn_rpn_proposals_op(self, roi_counts):
|
||||
batch_size = len(roi_counts)
|
||||
|
||||
Reference in New Issue
Block a user