diff --git a/caffe2/operators/gather_ranges_to_dense_op.h b/caffe2/operators/gather_ranges_to_dense_op.h index 98fe0e678dc..c1dd5a52700 100644 --- a/caffe2/operators/gather_ranges_to_dense_op.h +++ b/caffe2/operators/gather_ranges_to_dense_op.h @@ -166,37 +166,41 @@ class GatherRangesToDenseOp final : public Operator { // Check whether the empty and mismatch ratio exceeded the threshold. totalRanges_ += batchSize; for (int j = 0; j < OutputSize(); ++j) { - CAFFE_ENFORCE_GT( - std::max(totalRanges_, minObservation_) * maxMismatchedRatio_, - mismatchedRanges_[j], - "Ratio of range length mismatch for feature at index ", - j, - " is ", - (static_cast(mismatchedRanges_[j]) / - static_cast(totalRanges_)), - " (", - mismatchedRanges_[j], - "/", - totalRanges_, - ") which exceeds ", - maxMismatchedRatio_); + // Only check when the ratio is not set to allow all mismatches. + if (maxMismatchedRatio_ < 1.0) { + CAFFE_ENFORCE_GE( + std::max(totalRanges_, minObservation_) * maxMismatchedRatio_, + mismatchedRanges_[j], + "Ratio of range length mismatch for feature at index ", + j, + " is ", + (static_cast(mismatchedRanges_[j]) / + static_cast(totalRanges_)), + " (", + mismatchedRanges_[j], + "/", + totalRanges_, + ") which exceeds ", + maxMismatchedRatio_); + } - // +0.5 to make sure when maxEmptyRatio_ is 1, this enforce will always be - // satisfied. - CAFFE_ENFORCE_GT( - std::max(totalRanges_, minObservation_) * maxEmptyRatio_ + 0.5, - emptyRanges_[j], - "Ratio of empty ranges for feature at index ", - j, - " is ", - (static_cast(emptyRanges_[j]) / - static_cast(totalRanges_)), - " (", - emptyRanges_[j], - "/", - totalRanges_, - ") which exceeds ", - maxEmptyRatio_); + // Only check when the ratio is not set to allow all examples to be empty. + if (maxEmptyRatio_ < 1.0) { + CAFFE_ENFORCE_GE( + std::max(totalRanges_, minObservation_) * maxEmptyRatio_, + emptyRanges_[j], + "Ratio of empty ranges for feature at index ", + j, + " is ", + (static_cast(emptyRanges_[j]) / + static_cast(totalRanges_)), + " (", + emptyRanges_[j], + "/", + totalRanges_, + ") which exceeds ", + maxEmptyRatio_); + } } return true; diff --git a/caffe2/python/operator_test/gather_ranges_op_test.py b/caffe2/python/operator_test/gather_ranges_op_test.py index 621c0521f17..a8ec8c35719 100644 --- a/caffe2/python/operator_test/gather_ranges_op_test.py +++ b/caffe2/python/operator_test/gather_ranges_op_test.py @@ -239,7 +239,7 @@ class TestGatherRanges(serial.SerializedTestCase): workspace.FeedBlob("key", key) def getOpWithThreshold( - min_observation=2, max_mismatched_ratio=0.6, max_empty_ratio=None + min_observation=2, max_mismatched_ratio=0.5, max_empty_ratio=None ): return core.CreateOperator( "GatherRangesToDense", @@ -254,12 +254,12 @@ class TestGatherRanges(serial.SerializedTestCase): workspace.RunOperatorOnce(getOpWithThreshold()) workspace.RunOperatorOnce( - getOpWithThreshold(max_mismatched_ratio=0.4, min_observation=50) + getOpWithThreshold(max_mismatched_ratio=0.3, min_observation=50) ) with self.assertRaises(RuntimeError): workspace.RunOperatorOnce( - getOpWithThreshold(max_mismatched_ratio=0.4, min_observation=5) + getOpWithThreshold(max_mismatched_ratio=0.3, min_observation=5) ) with self.assertRaises(RuntimeError):