mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[fix][tiny][caffe2] Avoid triggering errors when allow ratio is 100% (#34757)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34757 Reviewed By: Wakeupbuddy Differential Revision: D20451255 fbshipit-source-id: 07997cf31dba653b61d082ec3f28357c3b90c4eb
This commit is contained in:
committed by
Facebook GitHub Bot
parent
24c9e61e79
commit
99b91ee2ad
@@ -166,37 +166,41 @@ class GatherRangesToDenseOp final : public Operator<Context> {
|
||||
// Check whether the empty and mismatch ratio exceeded the threshold.
|
||||
totalRanges_ += batchSize;
|
||||
for (int j = 0; j < OutputSize(); ++j) {
|
||||
CAFFE_ENFORCE_GT(
|
||||
std::max(totalRanges_, minObservation_) * maxMismatchedRatio_,
|
||||
mismatchedRanges_[j],
|
||||
"Ratio of range length mismatch for feature at index ",
|
||||
j,
|
||||
" is ",
|
||||
(static_cast<double>(mismatchedRanges_[j]) /
|
||||
static_cast<double>(totalRanges_)),
|
||||
" (",
|
||||
mismatchedRanges_[j],
|
||||
"/",
|
||||
totalRanges_,
|
||||
") which exceeds ",
|
||||
maxMismatchedRatio_);
|
||||
// Only check when the ratio is not set to allow all mismatches.
|
||||
if (maxMismatchedRatio_ < 1.0) {
|
||||
CAFFE_ENFORCE_GE(
|
||||
std::max(totalRanges_, minObservation_) * maxMismatchedRatio_,
|
||||
mismatchedRanges_[j],
|
||||
"Ratio of range length mismatch for feature at index ",
|
||||
j,
|
||||
" is ",
|
||||
(static_cast<double>(mismatchedRanges_[j]) /
|
||||
static_cast<double>(totalRanges_)),
|
||||
" (",
|
||||
mismatchedRanges_[j],
|
||||
"/",
|
||||
totalRanges_,
|
||||
") which exceeds ",
|
||||
maxMismatchedRatio_);
|
||||
}
|
||||
|
||||
// +0.5 to make sure when maxEmptyRatio_ is 1, this enforce will always be
|
||||
// satisfied.
|
||||
CAFFE_ENFORCE_GT(
|
||||
std::max(totalRanges_, minObservation_) * maxEmptyRatio_ + 0.5,
|
||||
emptyRanges_[j],
|
||||
"Ratio of empty ranges for feature at index ",
|
||||
j,
|
||||
" is ",
|
||||
(static_cast<double>(emptyRanges_[j]) /
|
||||
static_cast<double>(totalRanges_)),
|
||||
" (",
|
||||
emptyRanges_[j],
|
||||
"/",
|
||||
totalRanges_,
|
||||
") which exceeds ",
|
||||
maxEmptyRatio_);
|
||||
// Only check when the ratio is not set to allow all examples to be empty.
|
||||
if (maxEmptyRatio_ < 1.0) {
|
||||
CAFFE_ENFORCE_GE(
|
||||
std::max(totalRanges_, minObservation_) * maxEmptyRatio_,
|
||||
emptyRanges_[j],
|
||||
"Ratio of empty ranges for feature at index ",
|
||||
j,
|
||||
" is ",
|
||||
(static_cast<double>(emptyRanges_[j]) /
|
||||
static_cast<double>(totalRanges_)),
|
||||
" (",
|
||||
emptyRanges_[j],
|
||||
"/",
|
||||
totalRanges_,
|
||||
") which exceeds ",
|
||||
maxEmptyRatio_);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@@ -239,7 +239,7 @@ class TestGatherRanges(serial.SerializedTestCase):
|
||||
workspace.FeedBlob("key", key)
|
||||
|
||||
def getOpWithThreshold(
|
||||
min_observation=2, max_mismatched_ratio=0.6, max_empty_ratio=None
|
||||
min_observation=2, max_mismatched_ratio=0.5, max_empty_ratio=None
|
||||
):
|
||||
return core.CreateOperator(
|
||||
"GatherRangesToDense",
|
||||
@@ -254,12 +254,12 @@ class TestGatherRanges(serial.SerializedTestCase):
|
||||
workspace.RunOperatorOnce(getOpWithThreshold())
|
||||
|
||||
workspace.RunOperatorOnce(
|
||||
getOpWithThreshold(max_mismatched_ratio=0.4, min_observation=50)
|
||||
getOpWithThreshold(max_mismatched_ratio=0.3, min_observation=50)
|
||||
)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
workspace.RunOperatorOnce(
|
||||
getOpWithThreshold(max_mismatched_ratio=0.4, min_observation=5)
|
||||
getOpWithThreshold(max_mismatched_ratio=0.3, min_observation=5)
|
||||
)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
||||
Reference in New Issue
Block a user