[fix][tiny][caffe2] Avoid triggering errors when allow ratio is 100% (#34757)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34757

Reviewed By: Wakeupbuddy

Differential Revision: D20451255

fbshipit-source-id: 07997cf31dba653b61d082ec3f28357c3b90c4eb
This commit is contained in:
Xinyi Zhang
2020-03-16 11:33:05 -07:00
committed by Facebook GitHub Bot
parent 24c9e61e79
commit 99b91ee2ad
2 changed files with 37 additions and 33 deletions

View File

@@ -166,37 +166,41 @@ class GatherRangesToDenseOp final : public Operator<Context> {
// Check whether the empty and mismatch ratio exceeded the threshold.
totalRanges_ += batchSize;
for (int j = 0; j < OutputSize(); ++j) {
CAFFE_ENFORCE_GT(
std::max(totalRanges_, minObservation_) * maxMismatchedRatio_,
mismatchedRanges_[j],
"Ratio of range length mismatch for feature at index ",
j,
" is ",
(static_cast<double>(mismatchedRanges_[j]) /
static_cast<double>(totalRanges_)),
" (",
mismatchedRanges_[j],
"/",
totalRanges_,
") which exceeds ",
maxMismatchedRatio_);
// Only check when the ratio is not set to allow all mismatches.
if (maxMismatchedRatio_ < 1.0) {
CAFFE_ENFORCE_GE(
std::max(totalRanges_, minObservation_) * maxMismatchedRatio_,
mismatchedRanges_[j],
"Ratio of range length mismatch for feature at index ",
j,
" is ",
(static_cast<double>(mismatchedRanges_[j]) /
static_cast<double>(totalRanges_)),
" (",
mismatchedRanges_[j],
"/",
totalRanges_,
") which exceeds ",
maxMismatchedRatio_);
}
// +0.5 to make sure when maxEmptyRatio_ is 1, this enforce will always be
// satisfied.
CAFFE_ENFORCE_GT(
std::max(totalRanges_, minObservation_) * maxEmptyRatio_ + 0.5,
emptyRanges_[j],
"Ratio of empty ranges for feature at index ",
j,
" is ",
(static_cast<double>(emptyRanges_[j]) /
static_cast<double>(totalRanges_)),
" (",
emptyRanges_[j],
"/",
totalRanges_,
") which exceeds ",
maxEmptyRatio_);
// Only check when the ratio is not set to allow all examples to be empty.
if (maxEmptyRatio_ < 1.0) {
CAFFE_ENFORCE_GE(
std::max(totalRanges_, minObservation_) * maxEmptyRatio_,
emptyRanges_[j],
"Ratio of empty ranges for feature at index ",
j,
" is ",
(static_cast<double>(emptyRanges_[j]) /
static_cast<double>(totalRanges_)),
" (",
emptyRanges_[j],
"/",
totalRanges_,
") which exceeds ",
maxEmptyRatio_);
}
}
return true;

View File

@@ -239,7 +239,7 @@ class TestGatherRanges(serial.SerializedTestCase):
workspace.FeedBlob("key", key)
def getOpWithThreshold(
min_observation=2, max_mismatched_ratio=0.6, max_empty_ratio=None
min_observation=2, max_mismatched_ratio=0.5, max_empty_ratio=None
):
return core.CreateOperator(
"GatherRangesToDense",
@@ -254,12 +254,12 @@ class TestGatherRanges(serial.SerializedTestCase):
workspace.RunOperatorOnce(getOpWithThreshold())
workspace.RunOperatorOnce(
getOpWithThreshold(max_mismatched_ratio=0.4, min_observation=50)
getOpWithThreshold(max_mismatched_ratio=0.3, min_observation=50)
)
with self.assertRaises(RuntimeError):
workspace.RunOperatorOnce(
getOpWithThreshold(max_mismatched_ratio=0.4, min_observation=5)
getOpWithThreshold(max_mismatched_ratio=0.3, min_observation=5)
)
with self.assertRaises(RuntimeError):